light pkg/api{client,server} refact (#2659)

* tests: don't run crowdsec if not necessary
* make listen_uri report the random port number when 0 is requested
* move apiserver.getTLSAuthType() -> csconfig.TLSCfg.GetAuthType()
* move apiserver.isEnrolled() -> apiclient.ApiClient.IsEnrolled()
* extract function apiserver.recoverFromPanic()
* simplify and move APIServer.GetTLSConfig() -> TLSCfg.GetTLSConfig()
* moved TLSCfg type to csconfig/tls.go
* APIServer.InitController(): early return / happy path
* extract function apiserver.newGinLogger()
* lapi tests
* update unit test
* lint (testify)
* lint (whitespace, variable names)
* update docker tests
This commit is contained in:
mmetc 2023-12-14 14:54:11 +01:00 committed by GitHub
parent 67cdf91f94
commit 89f704ef18
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 927 additions and 477 deletions

View file

@ -13,7 +13,7 @@ def test_no_agent(crowdsec, flavor):
'DISABLE_AGENT': 'true', 'DISABLE_AGENT': 'true',
} }
with crowdsec(flavor=flavor, environment=env) as cs: with crowdsec(flavor=flavor, environment=env) as cs:
cs.wait_for_log("*CrowdSec Local API listening on 0.0.0.0:8080*") cs.wait_for_log("*CrowdSec Local API listening on *:8080*")
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
res = cs.cont.exec_run('cscli lapi status') res = cs.cont.exec_run('cscli lapi status')
assert res.exit_code == 0 assert res.exit_code == 0
@ -37,7 +37,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory):
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
cs.wait_for_log([ cs.wait_for_log([
"*Generate local agent credentials*", "*Generate local agent credentials*",
"*CrowdSec Local API listening on 0.0.0.0:8080*", "*CrowdSec Local API listening on *:8080*",
]) ])
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
res = cs.cont.exec_run('cscli lapi status') res = cs.cont.exec_run('cscli lapi status')
@ -50,7 +50,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory):
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
cs.wait_for_log([ cs.wait_for_log([
"*Generate local agent credentials*", "*Generate local agent credentials*",
"*CrowdSec Local API listening on 0.0.0.0:8080*", "*CrowdSec Local API listening on *:8080*",
]) ])
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
res = cs.cont.exec_run('cscli lapi status') res = cs.cont.exec_run('cscli lapi status')
@ -65,7 +65,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory):
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
cs.wait_for_log([ cs.wait_for_log([
"*Generate local agent credentials*", "*Generate local agent credentials*",
"*CrowdSec Local API listening on 0.0.0.0:8080*", "*CrowdSec Local API listening on *:8080*",
]) ])
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
res = cs.cont.exec_run('cscli lapi status') res = cs.cont.exec_run('cscli lapi status')
@ -78,7 +78,7 @@ def test_machine_register(crowdsec, flavor, tmp_path_factory):
with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs: with crowdsec(flavor=flavor, environment=env, volumes=volumes) as cs:
cs.wait_for_log([ cs.wait_for_log([
"*Local agent already registered*", "*Local agent already registered*",
"*CrowdSec Local API listening on 0.0.0.0:8080*", "*CrowdSec Local API listening on *:8080*",
]) ])
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
res = cs.cont.exec_run('cscli lapi status') res = cs.cont.exec_run('cscli lapi status')

View file

@ -29,7 +29,7 @@ def test_split_lapi_agent(crowdsec, flavor):
cs_agent = crowdsec(name=agentname, environment=agent_env, flavor=flavor) cs_agent = crowdsec(name=agentname, environment=agent_env, flavor=flavor)
with cs_lapi as lapi: with cs_lapi as lapi:
lapi.wait_for_log("*CrowdSec Local API listening on 0.0.0.0:8080*") lapi.wait_for_log("*CrowdSec Local API listening on *:8080*")
lapi.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) lapi.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
with cs_agent as agent: with cs_agent as agent:
agent.wait_for_log("*Starting processing data*") agent.wait_for_log("*Starting processing data*")

View file

@ -11,7 +11,7 @@ def test_local_api_url_default(crowdsec, flavor):
"""Test LOCAL_API_URL (default)""" """Test LOCAL_API_URL (default)"""
with crowdsec(flavor=flavor) as cs: with crowdsec(flavor=flavor) as cs:
cs.wait_for_log([ cs.wait_for_log([
"*CrowdSec Local API listening on 0.0.0.0:8080*", "*CrowdSec Local API listening on *:8080*",
"*Starting processing data*" "*Starting processing data*"
]) ])
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
@ -29,7 +29,7 @@ def test_local_api_url(crowdsec, flavor):
} }
with crowdsec(flavor=flavor, environment=env) as cs: with crowdsec(flavor=flavor, environment=env) as cs:
cs.wait_for_log([ cs.wait_for_log([
"*CrowdSec Local API listening on 0.0.0.0:8080*", "*CrowdSec Local API listening on *:8080*",
"*Starting processing data*" "*Starting processing data*"
]) ])
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
@ -54,7 +54,7 @@ def test_local_api_url_ipv6(crowdsec, flavor):
with crowdsec(flavor=flavor, environment=env) as cs: with crowdsec(flavor=flavor, environment=env) as cs:
cs.wait_for_log([ cs.wait_for_log([
"*Starting processing data*", "*Starting processing data*",
"*CrowdSec Local API listening on 0.0.0.0:8080*", "*CrowdSec Local API listening on [::1]:8080*",
]) ])
cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK) cs.wait_for_http(8080, '/health', want_status=HTTPStatus.OK)
res = cs.cont.exec_run('cscli lapi status') res = cs.cont.exec_run('cscli lapi status')

View file

@ -23,7 +23,7 @@ def test_missing_key_file(crowdsec, flavor):
with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs: with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs:
# XXX: this message appears twice, is that normal? # XXX: this message appears twice, is that normal?
cs.wait_for_log("*while serving local API: missing TLS key file*") cs.wait_for_log("*while starting API server: missing TLS key file*")
def test_missing_cert_file(crowdsec, flavor): def test_missing_cert_file(crowdsec, flavor):
@ -35,7 +35,7 @@ def test_missing_cert_file(crowdsec, flavor):
} }
with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs: with crowdsec(flavor=flavor, environment=env, wait_status=Status.EXITED) as cs:
cs.wait_for_log("*while serving local API: missing TLS cert file*") cs.wait_for_log("*while starting API server: missing TLS cert file*")
def test_tls_missing_ca(crowdsec, flavor, certs_dir): def test_tls_missing_ca(crowdsec, flavor, certs_dir):
@ -174,7 +174,7 @@ def test_tls_split_lapi_agent(crowdsec, flavor, certs_dir):
with cs_lapi as lapi: with cs_lapi as lapi:
lapi.wait_for_log([ lapi.wait_for_log([
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*", "*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
"*CrowdSec Local API listening on 0.0.0.0:8080*" "*CrowdSec Local API listening on *:8080*"
]) ])
# TODO: wait_for_https # TODO: wait_for_https
lapi.wait_for_http(8080, '/health', want_status=None) lapi.wait_for_http(8080, '/health', want_status=None)
@ -225,7 +225,7 @@ def test_tls_mutual_split_lapi_agent(crowdsec, flavor, certs_dir):
with cs_lapi as lapi: with cs_lapi as lapi:
lapi.wait_for_log([ lapi.wait_for_log([
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*", "*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
"*CrowdSec Local API listening on 0.0.0.0:8080*" "*CrowdSec Local API listening on *:8080*"
]) ])
# TODO: wait_for_https # TODO: wait_for_https
lapi.wait_for_http(8080, '/health', want_status=None) lapi.wait_for_http(8080, '/health', want_status=None)
@ -276,7 +276,7 @@ def test_tls_client_ou(crowdsec, certs_dir):
with cs_lapi as lapi: with cs_lapi as lapi:
lapi.wait_for_log([ lapi.wait_for_log([
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*", "*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
"*CrowdSec Local API listening on 0.0.0.0:8080*" "*CrowdSec Local API listening on *:8080*"
]) ])
# TODO: wait_for_https # TODO: wait_for_https
lapi.wait_for_http(8080, '/health', want_status=None) lapi.wait_for_http(8080, '/health', want_status=None)
@ -306,7 +306,7 @@ def test_tls_client_ou(crowdsec, certs_dir):
with cs_lapi as lapi: with cs_lapi as lapi:
lapi.wait_for_log([ lapi.wait_for_log([
"*(tls) Client Auth Type set to VerifyClientCertIfGiven*", "*(tls) Client Auth Type set to VerifyClientCertIfGiven*",
"*CrowdSec Local API listening on 0.0.0.0:8080*" "*CrowdSec Local API listening on *:8080*"
]) ])
# TODO: wait_for_https # TODO: wait_for_https
lapi.wait_for_http(8080, '/health', want_status=None) lapi.wait_for_http(8080, '/health', want_status=None)

View file

@ -49,31 +49,37 @@ type AlertsDeleteOpts struct {
} }
func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) { func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) {
var addedIds models.AddAlertsResponse
var added_ids models.AddAlertsResponse
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix) u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &alerts) req, err := s.client.NewRequest(http.MethodPost, u, &alerts)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
resp, err := s.client.Do(ctx, req, &added_ids) resp, err := s.client.Do(ctx, req, &addedIds)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return &added_ids, resp, nil
return &addedIds, resp, nil
} }
// to demo query arguments // to demo query arguments
func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.GetAlertsResponse, *Response, error) { func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.GetAlertsResponse, *Response, error) {
var alerts models.GetAlertsResponse var (
var URI string alerts models.GetAlertsResponse
URI string
)
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix) u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)
params, err := qs.Values(opts) params, err := qs.Values(opts)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("building query: %w", err) return nil, nil, fmt.Errorf("building query: %w", err)
} }
if len(params) > 0 { if len(params) > 0 {
URI = fmt.Sprintf("%s?%s", u, params.Encode()) URI = fmt.Sprintf("%s?%s", u, params.Encode())
} else { } else {
@ -89,16 +95,19 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.
if err != nil { if err != nil {
return nil, resp, fmt.Errorf("performing request: %w", err) return nil, resp, fmt.Errorf("performing request: %w", err)
} }
return &alerts, resp, nil return &alerts, resp, nil
} }
// to demo query arguments // to demo query arguments
func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*models.DeleteAlertsResponse, *Response, error) { func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*models.DeleteAlertsResponse, *Response, error) {
var alerts models.DeleteAlertsResponse var alerts models.DeleteAlertsResponse
params, err := qs.Values(opts) params, err := qs.Values(opts)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
u := fmt.Sprintf("%s/alerts?%s", s.client.URLPrefix, params.Encode()) u := fmt.Sprintf("%s/alerts?%s", s.client.URLPrefix, params.Encode())
req, err := s.client.NewRequest(http.MethodDelete, u, nil) req, err := s.client.NewRequest(http.MethodDelete, u, nil)
@ -110,12 +119,14 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return &alerts, resp, nil return &alerts, resp, nil
} }
func (s *AlertsService) DeleteOne(ctx context.Context, alert_id string) (*models.DeleteAlertsResponse, *Response, error) { func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.DeleteAlertsResponse, *Response, error) {
var alerts models.DeleteAlertsResponse var alerts models.DeleteAlertsResponse
u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alert_id)
u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alertID)
req, err := s.client.NewRequest(http.MethodDelete, u, nil) req, err := s.client.NewRequest(http.MethodDelete, u, nil)
if err != nil { if err != nil {
@ -126,11 +137,13 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alert_id string) (*models
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return &alerts, resp, nil return &alerts, resp, nil
} }
func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) { func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) {
var alert models.Alert var alert models.Alert
u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID) u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID)
req, err := s.client.NewRequest(http.MethodGet, u, nil) req, err := s.client.NewRequest(http.MethodGet, u, nil)
@ -142,5 +155,6 @@ func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return &alert, resp, nil return &alert, resp, nil
} }

View file

@ -26,10 +26,12 @@ func TestAlertsListAsMachine(t *testing.T) {
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
}) })
log.Printf("URL is %s", urlx) log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/") apiURL, err := url.Parse(urlx + "/")
if err != nil { if err != nil {
log.Fatalf("parsing api url: %s", apiURL) log.Fatalf("parsing api url: %s", apiURL)
} }
client, err := NewClient(&Config{ client, err := NewClient(&Config{
MachineID: "test_login", MachineID: "test_login",
Password: "test_password", Password: "test_password",
@ -199,6 +201,7 @@ func TestAlertsListAsMachine(t *testing.T) {
if err != nil { if err != nil {
log.Errorf("test Unable to list alerts : %+v", err) log.Errorf("test Unable to list alerts : %+v", err)
} }
if resp.Response.StatusCode != http.StatusOK { if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
} }
@ -209,14 +212,17 @@ func TestAlertsListAsMachine(t *testing.T) {
//this one doesn't //this one doesn't
filter := AlertsListOpts{IPEquals: new(string)} filter := AlertsListOpts{IPEquals: new(string)}
*filter.IPEquals = "1.2.3.4" *filter.IPEquals = "1.2.3.4"
alerts, resp, err = client.Alerts.List(context.Background(), filter) alerts, resp, err = client.Alerts.List(context.Background(), filter)
if err != nil { if err != nil {
log.Errorf("test Unable to list alerts : %+v", err) log.Errorf("test Unable to list alerts : %+v", err)
} }
if resp.Response.StatusCode != http.StatusOK { if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
} }
assert.Equal(t, 0, len(*alerts))
assert.Empty(t, *alerts)
} }
func TestAlertsGetAsMachine(t *testing.T) { func TestAlertsGetAsMachine(t *testing.T) {
@ -228,10 +234,12 @@ func TestAlertsGetAsMachine(t *testing.T) {
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
}) })
log.Printf("URL is %s", urlx) log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/") apiURL, err := url.Parse(urlx + "/")
if err != nil { if err != nil {
log.Fatalf("parsing api url: %s", apiURL) log.Fatalf("parsing api url: %s", apiURL)
} }
client, err := NewClient(&Config{ client, err := NewClient(&Config{
MachineID: "test_login", MachineID: "test_login",
Password: "test_password", Password: "test_password",
@ -390,6 +398,7 @@ func TestAlertsGetAsMachine(t *testing.T) {
alerts, resp, err := client.Alerts.GetByID(context.Background(), 1) alerts, resp, err := client.Alerts.GetByID(context.Background(), 1)
require.NoError(t, err) require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK { if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
} }
@ -401,7 +410,6 @@ func TestAlertsGetAsMachine(t *testing.T) {
//fail //fail
_, _, err = client.Alerts.GetByID(context.Background(), 2) _, _, err = client.Alerts.GetByID(context.Background(), 2)
assert.Contains(t, fmt.Sprintf("%s", err), "API error: object not found") assert.Contains(t, fmt.Sprintf("%s", err), "API error: object not found")
} }
func TestAlertsCreateAsMachine(t *testing.T) { func TestAlertsCreateAsMachine(t *testing.T) {
@ -418,10 +426,12 @@ func TestAlertsCreateAsMachine(t *testing.T) {
w.Write([]byte(`["3"]`)) w.Write([]byte(`["3"]`))
}) })
log.Printf("URL is %s", urlx) log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/") apiURL, err := url.Parse(urlx + "/")
if err != nil { if err != nil {
log.Fatalf("parsing api url: %s", apiURL) log.Fatalf("parsing api url: %s", apiURL)
} }
client, err := NewClient(&Config{ client, err := NewClient(&Config{
MachineID: "test_login", MachineID: "test_login",
Password: "test_password", Password: "test_password",
@ -435,13 +445,17 @@ func TestAlertsCreateAsMachine(t *testing.T) {
} }
defer teardown() defer teardown()
alert := models.AddAlertsRequest{} alert := models.AddAlertsRequest{}
alerts, resp, err := client.Alerts.Add(context.Background(), alert) alerts, resp, err := client.Alerts.Add(context.Background(), alert)
require.NoError(t, err) require.NoError(t, err)
expected := &models.AddAlertsResponse{"3"} expected := &models.AddAlertsResponse{"3"}
if resp.Response.StatusCode != http.StatusOK { if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
} }
if !reflect.DeepEqual(*alerts, *expected) { if !reflect.DeepEqual(*alerts, *expected) {
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
} }
@ -457,15 +471,17 @@ func TestAlertsDeleteAsMachine(t *testing.T) {
}) })
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "DELETE") testMethod(t, r, "DELETE")
assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"message":"0 deleted alerts"}`)) w.Write([]byte(`{"message":"0 deleted alerts"}`))
}) })
log.Printf("URL is %s", urlx) log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/") apiURL, err := url.Parse(urlx + "/")
if err != nil { if err != nil {
log.Fatalf("parsing api url: %s", apiURL) log.Fatalf("parsing api url: %s", apiURL)
} }
client, err := NewClient(&Config{ client, err := NewClient(&Config{
MachineID: "test_login", MachineID: "test_login",
Password: "test_password", Password: "test_password",
@ -479,15 +495,18 @@ func TestAlertsDeleteAsMachine(t *testing.T) {
} }
defer teardown() defer teardown()
alert := AlertsDeleteOpts{IPEquals: new(string)} alert := AlertsDeleteOpts{IPEquals: new(string)}
*alert.IPEquals = "1.2.3.4" *alert.IPEquals = "1.2.3.4"
alerts, resp, err := client.Alerts.Delete(context.Background(), alert) alerts, resp, err := client.Alerts.Delete(context.Background(), alert)
require.NoError(t, err) require.NoError(t, err)
expected := &models.DeleteAlertsResponse{NbDeleted: ""} expected := &models.DeleteAlertsResponse{NbDeleted: ""}
if resp.Response.StatusCode != http.StatusOK { if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
} }
if !reflect.DeepEqual(*alerts, *expected) { if !reflect.DeepEqual(*alerts, *expected) {
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
} }

View file

@ -41,10 +41,13 @@ func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// specification of http.RoundTripper. // specification of http.RoundTripper.
req = cloneRequest(req) req = cloneRequest(req)
req.Header.Add("X-Api-Key", t.APIKey) req.Header.Add("X-Api-Key", t.APIKey)
if t.UserAgent != "" { if t.UserAgent != "" {
req.Header.Add("User-Agent", t.UserAgent) req.Header.Add("User-Agent", t.UserAgent)
} }
log.Debugf("req-api: %s %s", req.Method, req.URL.String()) log.Debugf("req-api: %s %s", req.Method, req.URL.String())
if log.GetLevel() >= log.TraceLevel { if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpRequest(req, true) dump, _ := httputil.DumpRequest(req, true)
log.Tracef("auth-api request: %s", string(dump)) log.Tracef("auth-api request: %s", string(dump))
@ -55,6 +58,7 @@ func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err) log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err)
return resp, err return resp, err
} }
if log.GetLevel() >= log.TraceLevel { if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpResponse(resp, true) dump, _ := httputil.DumpResponse(resp, true)
log.Tracef("auth-api response: %s", string(dump)) log.Tracef("auth-api response: %s", string(dump))
@ -73,6 +77,7 @@ func (t *APIKeyTransport) transport() http.RoundTripper {
if t.Transport != nil { if t.Transport != nil {
return t.Transport return t.Transport
} }
return http.DefaultTransport return http.DefaultTransport
} }
@ -90,15 +95,19 @@ func (r retryRoundTripper) ShouldRetry(statusCode int) bool {
return true return true
} }
} }
return false return false
} }
func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
var resp *http.Response var (
var err error resp *http.Response
err error
)
backoff := 0 backoff := 0
maxAttempts := r.maxAttempts maxAttempts := r.maxAttempts
if fflag.DisableHttpRetryBackoff.IsEnabled() { if fflag.DisableHttpRetryBackoff.IsEnabled() {
maxAttempts = 1 maxAttempts = 1
} }
@ -108,6 +117,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
if r.withBackOff { if r.withBackOff {
backoff += 10 + rand.Intn(20) backoff += 10 + rand.Intn(20)
} }
log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts) log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)
select { select {
case <-req.Context().Done(): case <-req.Context().Done():
@ -115,22 +125,28 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
case <-time.After(time.Duration(backoff) * time.Second): case <-time.After(time.Duration(backoff) * time.Second):
} }
} }
if r.onBeforeRequest != nil { if r.onBeforeRequest != nil {
r.onBeforeRequest(i) r.onBeforeRequest(i)
} }
clonedReq := cloneRequest(req) clonedReq := cloneRequest(req)
resp, err = r.next.RoundTrip(clonedReq) resp, err = r.next.RoundTrip(clonedReq)
if err != nil { if err != nil {
left := maxAttempts - i - 1 left := maxAttempts - i - 1
if left > 0 { if left > 0 {
log.Errorf("error while performing request: %s; %d retries left", err, left) log.Errorf("error while performing request: %s; %d retries left", err, left)
} }
continue continue
} }
if !r.ShouldRetry(resp.StatusCode) { if !r.ShouldRetry(resp.StatusCode) {
return resp, nil return resp, nil
} }
} }
return resp, err return resp, err
} }
@ -157,6 +173,7 @@ func (t *JWTTransport) refreshJwtToken() error {
if err != nil { if err != nil {
return fmt.Errorf("can't update scenario list: %s", err) return fmt.Errorf("can't update scenario list: %s", err)
} }
log.Debugf("scenarios list updated for '%s'", *t.MachineID) log.Debugf("scenarios list updated for '%s'", *t.MachineID)
} }
@ -175,14 +192,18 @@ func (t *JWTTransport) refreshJwtToken() error {
enc := json.NewEncoder(buf) enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false) enc.SetEscapeHTML(false)
err = enc.Encode(auth) err = enc.Encode(auth)
if err != nil { if err != nil {
return fmt.Errorf("could not encode jwt auth body: %w", err) return fmt.Errorf("could not encode jwt auth body: %w", err)
} }
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf) req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s%s/watchers/login", t.URL, t.VersionPrefix), buf)
if err != nil { if err != nil {
return fmt.Errorf("could not create request: %w", err) return fmt.Errorf("could not create request: %w", err)
} }
req.Header.Add("Content-Type", "application/json") req.Header.Add("Content-Type", "application/json")
client := &http.Client{ client := &http.Client{
Transport: &retryRoundTripper{ Transport: &retryRoundTripper{
next: http.DefaultTransport, next: http.DefaultTransport,
@ -191,9 +212,11 @@ func (t *JWTTransport) refreshJwtToken() error {
retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError}, retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError},
}, },
} }
if t.UserAgent != "" { if t.UserAgent != "" {
req.Header.Add("User-Agent", t.UserAgent) req.Header.Add("User-Agent", t.UserAgent)
} }
if log.GetLevel() >= log.TraceLevel { if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpRequest(req, true) dump, _ := httputil.DumpRequest(req, true)
log.Tracef("auth-jwt request: %s", string(dump)) log.Tracef("auth-jwt request: %s", string(dump))
@ -205,6 +228,7 @@ func (t *JWTTransport) refreshJwtToken() error {
if err != nil { if err != nil {
return fmt.Errorf("could not get jwt token: %w", err) return fmt.Errorf("could not get jwt token: %w", err)
} }
log.Debugf("auth-jwt : http %d", resp.StatusCode) log.Debugf("auth-jwt : http %d", resp.StatusCode)
if log.GetLevel() >= log.TraceLevel { if log.GetLevel() >= log.TraceLevel {
@ -226,12 +250,15 @@ func (t *JWTTransport) refreshJwtToken() error {
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return fmt.Errorf("unable to decode response: %w", err) return fmt.Errorf("unable to decode response: %w", err)
} }
if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil { if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
return fmt.Errorf("unable to parse jwt expiration: %w", err) return fmt.Errorf("unable to parse jwt expiration: %w", err)
} }
t.Token = response.Token t.Token = response.Token
log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String()) log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
return nil return nil
} }
@ -267,6 +294,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
dump, _ := httputil.DumpResponse(resp, true) dump, _ := httputil.DumpResponse(resp, true)
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err) log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
} }
if err != nil { if err != nil {
/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/ /*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
t.Token = "" t.Token = ""
@ -333,9 +361,12 @@ func cloneRequest(r *http.Request) *http.Request {
if r.Body != nil { if r.Body != nil {
var b bytes.Buffer var b bytes.Buffer
b.ReadFrom(r.Body) b.ReadFrom(r.Body)
r.Body = io.NopCloser(&b) r.Body = io.NopCloser(&b)
r2.Body = io.NopCloser(bytes.NewReader(b.Bytes())) r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
} }
return r2 return r2
} }

View file

@ -22,6 +22,7 @@ type enrollRequest struct {
func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error) { func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error) {
u := fmt.Sprintf("%s/watchers", s.client.URLPrefix) u := fmt.Sprintf("%s/watchers", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodDelete, u, nil) req, err := s.client.NewRequest(http.MethodDelete, u, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -31,6 +32,7 @@ func (s *AuthService) UnregisterWatcher(ctx context.Context) (*Response, error)
if err != nil { if err != nil {
return resp, err return resp, err
} }
return resp, nil return resp, nil
} }
@ -46,6 +48,7 @@ func (s *AuthService) RegisterWatcher(ctx context.Context, registration models.W
if err != nil { if err != nil {
return resp, err return resp, err
} }
return resp, nil return resp, nil
} }
@ -53,6 +56,7 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch
var authResp models.WatcherAuthResponse var authResp models.WatcherAuthResponse
u := fmt.Sprintf("%s/watchers/login", s.client.URLPrefix) u := fmt.Sprintf("%s/watchers/login", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &auth) req, err := s.client.NewRequest(http.MethodPost, u, &auth)
if err != nil { if err != nil {
return authResp, nil, err return authResp, nil, err
@ -62,11 +66,13 @@ func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.Watch
if err != nil { if err != nil {
return authResp, resp, err return authResp, resp, err
} }
return authResp, resp, nil return authResp, resp, nil
} }
func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name string, tags []string, overwrite bool) (*Response, error) { func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name string, tags []string, overwrite bool) (*Response, error) {
u := fmt.Sprintf("%s/watchers/enroll", s.client.URLPrefix) u := fmt.Sprintf("%s/watchers/enroll", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &enrollRequest{EnrollKey: enrollKey, Name: name, Tags: tags, Overwrite: overwrite}) req, err := s.client.NewRequest(http.MethodPost, u, &enrollRequest{EnrollKey: enrollKey, Name: name, Tags: tags, Overwrite: overwrite})
if err != nil { if err != nil {
return nil, err return nil, err
@ -76,5 +82,6 @@ func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name
if err != nil { if err != nil {
return resp, err return resp, err
} }
return resp, nil return resp, nil
} }

View file

@ -35,6 +35,7 @@ func getLoginsForMockErrorCases() map[string]int {
func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
loginsForMockErrorCases := getLoginsForMockErrorCases() loginsForMockErrorCases := getLoginsForMockErrorCases()
mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "POST") testMethod(t, r, "POST")
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
@ -71,7 +72,6 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
* 400, 409, 500 => Error * 400, 409, 500 => Error
*/ */
func TestWatcherRegister(t *testing.T) { func TestWatcherRegister(t *testing.T) {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
mux, urlx, teardown := setup() mux, urlx, teardown := setup()
@ -79,6 +79,7 @@ func TestWatcherRegister(t *testing.T) {
//body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
initBasicMuxMock(t, mux, "/watchers") initBasicMuxMock(t, mux, "/watchers")
log.Printf("URL is %s", urlx) log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/") apiURL, err := url.Parse(urlx + "/")
if err != nil { if err != nil {
t.Fatalf("parsing api url: %s", apiURL) t.Fatalf("parsing api url: %s", apiURL)
@ -92,16 +93,19 @@ func TestWatcherRegister(t *testing.T) {
URL: apiURL, URL: apiURL,
VersionPrefix: "v1", VersionPrefix: "v1",
} }
client, err := RegisterClient(&clientconfig, &http.Client{}) client, err := RegisterClient(&clientconfig, &http.Client{})
if client == nil || err != nil { if client == nil || err != nil {
t.Fatalf("while registering client : %s", err) t.Fatalf("while registering client : %s", err)
} }
log.Printf("->%T", client) log.Printf("->%T", client)
// Testing error handling on Registration (400, 409, 500): should retrieve an error // Testing error handling on Registration (400, 409, 500): should retrieve an error
errorCodesToTest := [3]int{http.StatusBadRequest, http.StatusConflict, http.StatusInternalServerError} errorCodesToTest := [3]int{http.StatusBadRequest, http.StatusConflict, http.StatusInternalServerError}
for _, errorCodeToTest := range errorCodesToTest { for _, errorCodeToTest := range errorCodesToTest {
clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
client, err = RegisterClient(&clientconfig, &http.Client{}) client, err = RegisterClient(&clientconfig, &http.Client{})
if client != nil || err == nil { if client != nil || err == nil {
t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest) t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest)
@ -112,7 +116,6 @@ func TestWatcherRegister(t *testing.T) {
} }
func TestWatcherAuth(t *testing.T) { func TestWatcherAuth(t *testing.T) {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
mux, urlx, teardown := setup() mux, urlx, teardown := setup()
@ -121,6 +124,7 @@ func TestWatcherAuth(t *testing.T) {
initBasicMuxMock(t, mux, "/watchers/login") initBasicMuxMock(t, mux, "/watchers/login")
log.Printf("URL is %s", urlx) log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/") apiURL, err := url.Parse(urlx + "/")
if err != nil { if err != nil {
t.Fatalf("parsing api url: %s", apiURL) t.Fatalf("parsing api url: %s", apiURL)
@ -169,6 +173,7 @@ func TestWatcherAuth(t *testing.T) {
if err == nil { if err == nil {
resp.Response.Body.Close() resp.Response.Body.Close()
bodyBytes, err := io.ReadAll(resp.Response.Body) bodyBytes, err := io.ReadAll(resp.Response.Body)
if err != nil { if err != nil {
t.Fatalf("error while reading body: %s", err.Error()) t.Fatalf("error while reading body: %s", err.Error())
@ -176,14 +181,13 @@ func TestWatcherAuth(t *testing.T) {
log.Printf(string(bodyBytes)) log.Printf(string(bodyBytes))
t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest) t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest)
} else {
log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest)
} }
log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest)
} }
} }
func TestWatcherUnregister(t *testing.T) { func TestWatcherUnregister(t *testing.T) {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
mux, urlx, teardown := setup() mux, urlx, teardown := setup()
@ -192,7 +196,7 @@ func TestWatcherUnregister(t *testing.T) {
mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "DELETE") testMethod(t, r, "DELETE")
assert.Equal(t, r.ContentLength, int64(0)) assert.Equal(t, int64(0), r.ContentLength)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
@ -211,10 +215,12 @@ func TestWatcherUnregister(t *testing.T) {
}) })
log.Printf("URL is %s", urlx) log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/") apiURL, err := url.Parse(urlx + "/")
if err != nil { if err != nil {
t.Fatalf("parsing api url: %s", apiURL) t.Fatalf("parsing api url: %s", apiURL)
} }
mycfg := &Config{ mycfg := &Config{
MachineID: "test_login", MachineID: "test_login",
Password: "test_password", Password: "test_password",
@ -228,10 +234,12 @@ func TestWatcherUnregister(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("new api client: %s", err) t.Fatalf("new api client: %s", err)
} }
_, err = client.Auth.UnregisterWatcher(context.Background()) _, err = client.Auth.UnregisterWatcher(context.Background())
if err != nil { if err != nil {
t.Fatalf("while registering client : %s", err) t.Fatalf("while registering client : %s", err)
} }
log.Printf("->%T", client) log.Printf("->%T", client)
} }
@ -264,6 +272,7 @@ func TestWatcherEnroll(t *testing.T) {
fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`) fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`)
}) })
log.Printf("URL is %s", urlx) log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/") apiURL, err := url.Parse(urlx + "/")
if err != nil { if err != nil {
t.Fatalf("parsing api url: %s", apiURL) t.Fatalf("parsing api url: %s", apiURL)

View file

@ -18,7 +18,7 @@ func TestApiAuth(t *testing.T) {
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET") testMethod(t, r, "GET")
if r.Header.Get("X-Api-Key") == "ixu" { if r.Header.Get("X-Api-Key") == "ixu" {
assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`null`)) w.Write([]byte(`null`))
} else { } else {
@ -66,9 +66,11 @@ func TestApiAuth(t *testing.T) {
_, resp, err = newcli.Decisions.List(context.Background(), alert) _, resp, err = newcli.Decisions.List(context.Background(), alert)
log.Infof("--> %s", err) log.Infof("--> %s", err)
if resp.Response.StatusCode != http.StatusForbidden { if resp.Response.StatusCode != http.StatusForbidden {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
} }
assert.Contains(t, err.Error(), "API error: access forbidden") assert.Contains(t, err.Error(), "API error: access forbidden")
//ko empty token //ko empty token
auth = &APIKeyTransport{} auth = &APIKeyTransport{}
@ -82,5 +84,4 @@ func TestApiAuth(t *testing.T) {
log.Infof("--> %s", err) log.Infof("--> %s", err)
assert.Contains(t, err.Error(), "APIKey is empty") assert.Contains(t, err.Error(), "APIKey is empty")
} }

View file

@ -10,6 +10,8 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"github.com/golang-jwt/jwt/v4"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
) )
@ -43,6 +45,21 @@ func (a *ApiClient) GetClient() *http.Client {
return a.client return a.client
} }
func (a *ApiClient) IsEnrolled() bool {
jwtTransport := a.client.Transport.(*JWTTransport)
tokenStr := jwtTransport.Token
token, _ := jwt.Parse(tokenStr, nil)
if token == nil {
return false
}
claims := token.Claims.(jwt.MapClaims)
_, ok := claims["organization_id"]
return ok
}
type service struct { type service struct {
client *ApiClient client *ApiClient
} }
@ -59,12 +76,15 @@ func NewClient(config *Config) (*ApiClient, error) {
} }
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
tlsconfig.RootCAs = CaCertPool tlsconfig.RootCAs = CaCertPool
if Cert != nil { if Cert != nil {
tlsconfig.Certificates = []tls.Certificate{*Cert} tlsconfig.Certificates = []tls.Certificate{*Cert}
} }
if ht, ok := http.DefaultTransport.(*http.Transport); ok { if ht, ok := http.DefaultTransport.(*http.Transport); ok {
ht.TLSClientConfig = &tlsconfig ht.TLSClientConfig = &tlsconfig
} }
c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL} c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL}
c.common.client = c c.common.client = c
c.Decisions = (*DecisionsService)(&c.common) c.Decisions = (*DecisionsService)(&c.common)
@ -81,16 +101,20 @@ func NewClient(config *Config) (*ApiClient, error) {
func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *http.Client) (*ApiClient, error) { func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *http.Client) (*ApiClient, error) {
if client == nil { if client == nil {
client = &http.Client{} client = &http.Client{}
if ht, ok := http.DefaultTransport.(*http.Transport); ok { if ht, ok := http.DefaultTransport.(*http.Transport); ok {
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
tlsconfig.RootCAs = CaCertPool tlsconfig.RootCAs = CaCertPool
if Cert != nil { if Cert != nil {
tlsconfig.Certificates = []tls.Certificate{*Cert} tlsconfig.Certificates = []tls.Certificate{*Cert}
} }
ht.TLSClientConfig = &tlsconfig ht.TLSClientConfig = &tlsconfig
client.Transport = ht client.Transport = ht
} }
} }
c := &ApiClient{client: client, BaseURL: URL, UserAgent: userAgent, URLPrefix: prefix} c := &ApiClient{client: client, BaseURL: URL, UserAgent: userAgent, URLPrefix: prefix}
c.common.client = c c.common.client = c
c.Decisions = (*DecisionsService)(&c.common) c.Decisions = (*DecisionsService)(&c.common)
@ -108,11 +132,13 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
if client == nil { if client == nil {
client = &http.Client{} client = &http.Client{}
} }
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
if Cert != nil { if Cert != nil {
tlsconfig.RootCAs = CaCertPool tlsconfig.RootCAs = CaCertPool
tlsconfig.Certificates = []tls.Certificate{*Cert} tlsconfig.Certificates = []tls.Certificate{*Cert}
} }
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig
c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix} c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix}
c.common.client = c c.common.client = c
@ -126,10 +152,11 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
if resp != nil && resp.Response != nil { if resp != nil && resp.Response != nil {
return nil, fmt.Errorf("api register (%s) http %s: %w", c.BaseURL, resp.Response.Status, err) return nil, fmt.Errorf("api register (%s) http %s: %w", c.BaseURL, resp.Response.Status, err)
} }
return nil, fmt.Errorf("api register (%s): %w", c.BaseURL, err) return nil, fmt.Errorf("api register (%s): %w", c.BaseURL, err)
} }
return c, nil
return c, nil
} }
type Response struct { type Response struct {
@ -148,6 +175,7 @@ func (e *ErrorResponse) Error() string {
if len(e.Errors) > 0 { if len(e.Errors) > 0 {
err += fmt.Sprintf(" (%s)", e.Errors) err += fmt.Sprintf(" (%s)", e.Errors)
} }
return err return err
} }
@ -160,7 +188,9 @@ func CheckResponse(r *http.Response) error {
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
return nil return nil
} }
errorResponse := &ErrorResponse{} errorResponse := &ErrorResponse{}
data, err := io.ReadAll(r.Body) data, err := io.ReadAll(r.Body)
if err == nil && data != nil { if err == nil && data != nil {
err := json.Unmarshal(data, errorResponse) err := json.Unmarshal(data, errorResponse)
@ -171,6 +201,7 @@ func CheckResponse(r *http.Response) error {
errorResponse.Message = new(string) errorResponse.Message = new(string)
*errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode) *errorResponse.Message = fmt.Sprintf("http code %d, no error message", r.StatusCode)
} }
return errorResponse return errorResponse
} }

View file

@ -19,6 +19,7 @@ func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Requ
if !strings.HasSuffix(c.BaseURL.Path, "/") { if !strings.HasSuffix(c.BaseURL.Path, "/") {
return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL) return nil, fmt.Errorf("BaseURL must have a trailing slash, but %q does not", c.BaseURL)
} }
u, err := c.BaseURL.Parse(url) u, err := c.BaseURL.Parse(url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -29,8 +30,8 @@ func (c *ApiClient) NewRequest(method, url string, body interface{}) (*http.Requ
buf = &bytes.Buffer{} buf = &bytes.Buffer{}
enc := json.NewEncoder(buf) enc := json.NewEncoder(buf)
enc.SetEscapeHTML(false) enc.SetEscapeHTML(false)
err := enc.Encode(body)
if err != nil { if err = enc.Encode(body); err != nil {
return nil, err return nil, err
} }
} }
@ -51,6 +52,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (*
if ctx == nil { if ctx == nil {
return nil, errors.New("context must be non-nil") return nil, errors.New("context must be non-nil")
} }
req = req.WithContext(ctx) req = req.WithContext(ctx)
// Check rate limit // Check rate limit
@ -62,6 +64,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (*
if log.GetLevel() >= log.DebugLevel { if log.GetLevel() >= log.DebugLevel {
log.Debugf("[URL] %s %s", req.Method, req.URL) log.Debugf("[URL] %s %s", req.Method, req.URL)
} }
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
defer resp.Body.Close() defer resp.Body.Close()
@ -82,8 +85,10 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (*
e.URL = url.String() e.URL = url.String()
return newResponse(resp), e return newResponse(resp), e
} }
return newResponse(resp), err return newResponse(resp), err
} }
return newResponse(resp), err return newResponse(resp), err
} }
@ -112,9 +117,12 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (*
if errors.Is(decErr, io.EOF) { if errors.Is(decErr, io.EOF) {
decErr = nil // ignore EOF errors caused by empty response body decErr = nil // ignore EOF errors caused by empty response body
} }
return response, decErr return response, decErr
} }
io.Copy(w, resp.Body) io.Copy(w, resp.Body)
} }
return response, err return response, err
} }

View file

@ -21,6 +21,7 @@ func TestNewRequestInvalid(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("parsing api url: %s", apiURL) t.Fatalf("parsing api url: %s", apiURL)
} }
client, err := NewClient(&Config{ client, err := NewClient(&Config{
MachineID: "test_login", MachineID: "test_login",
Password: "test_password", Password: "test_password",
@ -54,6 +55,7 @@ func TestNewRequestTimeout(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("parsing api url: %s", apiURL) t.Fatalf("parsing api url: %s", apiURL)
} }
client, err := NewClient(&Config{ client, err := NewClient(&Config{
MachineID: "test_login", MachineID: "test_login",
Password: "test_password", Password: "test_password",

View file

@ -40,6 +40,7 @@ func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, te
func testMethod(t *testing.T, r *http.Request, want string) { func testMethod(t *testing.T, r *http.Request, want string) {
t.Helper() t.Helper()
if got := r.Method; got != want { if got := r.Method; got != want {
t.Errorf("Request method: %v, want %v", got, want) t.Errorf("Request method: %v, want %v", got, want)
} }
@ -77,6 +78,7 @@ func TestNewClientOk(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("test Unable to list alerts : %+v", err) t.Fatalf("test Unable to list alerts : %+v", err)
} }
if resp.Response.StatusCode != http.StatusOK { if resp.Response.StatusCode != http.StatusOK {
t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated) t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated)
} }
@ -126,6 +128,7 @@ func TestNewDefaultClient(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("new api client: %s", err) t.Fatalf("new api client: %s", err)
} }
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"code": 401, "message" : "brr"}`)) w.Write([]byte(`{"code": 401, "message" : "brr"}`))
@ -157,6 +160,7 @@ func TestNewClientRegisterKO(t *testing.T) {
func TestNewClientRegisterOK(t *testing.T) { func TestNewClientRegisterOK(t *testing.T) {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)
mux, urlx, teardown := setup() mux, urlx, teardown := setup()
defer teardown() defer teardown()
/*mock login*/ /*mock login*/
@ -180,12 +184,14 @@ func TestNewClientRegisterOK(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("while registering client : %s", err) t.Fatalf("while registering client : %s", err)
} }
log.Printf("->%T", client) log.Printf("->%T", client)
} }
func TestNewClientBadAnswer(t *testing.T) { func TestNewClientBadAnswer(t *testing.T) {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)
mux, urlx, teardown := setup() mux, urlx, teardown := setup()
defer teardown() defer teardown()
/*mock login*/ /*mock login*/

View file

@ -42,6 +42,7 @@ func (o *DecisionsStreamOpts) addQueryParamsToURL(url string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return fmt.Sprintf("%s?%s", url, params.Encode()), nil return fmt.Sprintf("%s?%s", url, params.Encode()), nil
} }
@ -61,10 +62,12 @@ type DecisionsDeleteOpts struct {
// to demo query arguments // to demo query arguments
func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*models.GetDecisionsResponse, *Response, error) { func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*models.GetDecisionsResponse, *Response, error) {
var decisions models.GetDecisionsResponse var decisions models.GetDecisionsResponse
params, err := qs.Values(opts) params, err := qs.Values(opts)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode()) u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode())
req, err := s.client.NewRequest(http.MethodGet, u, nil) req, err := s.client.NewRequest(http.MethodGet, u, nil)
@ -111,14 +114,18 @@ func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi.
Origin: ptr.Of(types.CAPIOrigin), Origin: ptr.Of(types.CAPIOrigin),
} }
} }
decisions = append(decisions, partialDecisions...) decisions = append(decisions, partialDecisions...)
} }
return decisions return decisions
} }
func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) { func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) {
var decisions modelscapi.GetDecisionsStreamResponse var (
var v2Decisions models.DecisionsStreamResponse decisions modelscapi.GetDecisionsStreamResponse
v2Decisions models.DecisionsStreamResponse
)
scenarioDeleted := "deleted" scenarioDeleted := "deleted"
durationDeleted := "1h" durationDeleted := "1h"
@ -134,8 +141,10 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
} }
v2Decisions.New = s.GetDecisionsFromGroups(decisions.New) v2Decisions.New = s.GetDecisionsFromGroups(decisions.New)
for _, decisionsGroup := range decisions.Deleted { for _, decisionsGroup := range decisions.Deleted {
partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions)) partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions))
for idx, decision := range decisionsGroup.Decisions { for idx, decision := range decisionsGroup.Decisions {
decision := decision // fix exportloopref linter message decision := decision // fix exportloopref linter message
partialDecisions[idx] = &models.Decision{ partialDecisions[idx] = &models.Decision{
@ -147,6 +156,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
Origin: ptr.Of(types.CAPIOrigin), Origin: ptr.Of(types.CAPIOrigin),
} }
} }
v2Decisions.Deleted = append(v2Decisions.Deleted, partialDecisions...) v2Decisions.Deleted = append(v2Decisions.Deleted, partialDecisions...)
} }
@ -161,6 +171,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
log.Debugf("Fetching blocklist %s", *blocklist.URL) log.Debugf("Fetching blocklist %s", *blocklist.URL)
client := http.Client{} client := http.Client{}
req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil) req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
@ -169,6 +180,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
if lastPullTimestamp != nil { if lastPullTimestamp != nil {
req.Header.Set("If-Modified-Since", *lastPullTimestamp) req.Header.Set("If-Modified-Since", *lastPullTimestamp)
} }
req = req.WithContext(ctx) req = req.WithContext(ctx)
log.Debugf("[URL] %s %s", req.Method, req.URL) log.Debugf("[URL] %s %s", req.Method, req.URL)
// we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc // we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc
@ -188,6 +200,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
// If the error type is *url.Error, sanitize its URL before returning. // If the error type is *url.Error, sanitize its URL before returning.
log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err) log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err)
return nil, false, err return nil, false, err
} }
@ -197,13 +210,17 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
} else { } else {
log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL) log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL)
} }
return nil, false, nil return nil, false, nil
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL) log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL)
return nil, false, nil return nil, false, nil
} }
decisions := make([]*models.Decision, 0) decisions := make([]*models.Decision, 0)
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() { for scanner.Scan() {
decision := scanner.Text() decision := scanner.Text()
@ -227,6 +244,7 @@ func (s *DecisionsService) GetStream(ctx context.Context, opts DecisionsStreamOp
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if s.client.URLPrefix == "v3" { if s.client.URLPrefix == "v3" {
return s.FetchV3Decisions(ctx, u) return s.FetchV3Decisions(ctx, u)
} else { } else {
@ -239,6 +257,7 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
var decisions modelscapi.GetDecisionsStreamResponse var decisions modelscapi.GetDecisionsStreamResponse
req, err := s.client.NewRequest(http.MethodGet, u, nil) req, err := s.client.NewRequest(http.MethodGet, u, nil)
@ -255,8 +274,8 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream
} }
func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) { func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) {
u := fmt.Sprintf("%s/decisions", s.client.URLPrefix) u := fmt.Sprintf("%s/decisions", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodDelete, u, nil) req, err := s.client.NewRequest(http.MethodDelete, u, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -266,15 +285,18 @@ func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) {
if err != nil { if err != nil {
return resp, err return resp, err
} }
return resp, nil return resp, nil
} }
func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) (*models.DeleteDecisionResponse, *Response, error) { func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) (*models.DeleteDecisionResponse, *Response, error) {
var deleteDecisionResponse models.DeleteDecisionResponse var deleteDecisionResponse models.DeleteDecisionResponse
params, err := qs.Values(opts) params, err := qs.Values(opts)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode()) u := fmt.Sprintf("%s/decisions?%s", s.client.URLPrefix, params.Encode())
req, err := s.client.NewRequest(http.MethodDelete, u, nil) req, err := s.client.NewRequest(http.MethodDelete, u, nil)
@ -286,12 +308,14 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts)
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return &deleteDecisionResponse, resp, nil return &deleteDecisionResponse, resp, nil
} }
func (s *DecisionsService) DeleteOne(ctx context.Context, decision_id string) (*models.DeleteDecisionResponse, *Response, error) { func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*models.DeleteDecisionResponse, *Response, error) {
var deleteDecisionResponse models.DeleteDecisionResponse var deleteDecisionResponse models.DeleteDecisionResponse
u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decision_id)
u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decisionID)
req, err := s.client.NewRequest(http.MethodDelete, u, nil) req, err := s.client.NewRequest(http.MethodDelete, u, nil)
if err != nil { if err != nil {
@ -302,5 +326,6 @@ func (s *DecisionsService) DeleteOne(ctx context.Context, decision_id string) (*
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return &deleteDecisionResponse, resp, nil return &deleteDecisionResponse, resp, nil
} }

View file

@ -28,8 +28,8 @@ func TestDecisionsList(t *testing.T) {
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET") testMethod(t, r, "GET")
if r.URL.RawQuery == "ip=1.2.3.4" { if r.URL.RawQuery == "ip=1.2.3.4" {
assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`[{"duration":"3h59m55.756182786s","id":4,"origin":"cscli","scenario":"manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'","scope":"Ip","type":"ban","value":"1.2.3.4"}]`)) w.Write([]byte(`[{"duration":"3h59m55.756182786s","id":4,"origin":"cscli","scenario":"manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'","scope":"Ip","type":"ban","value":"1.2.3.4"}]`))
} else { } else {
@ -83,6 +83,7 @@ func TestDecisionsList(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("new api client: %s", err) t.Fatalf("new api client: %s", err)
} }
if !reflect.DeepEqual(*decisions, *expected) { if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected) t.Fatalf("returned %+v, want %+v", resp, expected)
} }
@ -96,8 +97,8 @@ func TestDecisionsList(t *testing.T) {
if resp.Response.StatusCode != http.StatusOK { if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
} }
assert.Equal(t, len(*decisions), 0)
assert.Empty(t, *decisions)
} }
func TestDecisionsStream(t *testing.T) { func TestDecisionsStream(t *testing.T) {
@ -107,8 +108,7 @@ func TestDecisionsStream(t *testing.T) {
defer teardown() defer teardown()
mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu")
testMethod(t, r, http.MethodGet) testMethod(t, r, http.MethodGet)
if r.Method == http.MethodGet { if r.Method == http.MethodGet {
if r.URL.RawQuery == "startup=true" { if r.URL.RawQuery == "startup=true" {
@ -121,7 +121,7 @@ func TestDecisionsStream(t *testing.T) {
} }
}) })
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
testMethod(t, r, http.MethodDelete) testMethod(t, r, http.MethodDelete)
if r.Method == http.MethodDelete { if r.Method == http.MethodDelete {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@ -173,6 +173,7 @@ func TestDecisionsStream(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("new api client: %s", err) t.Fatalf("new api client: %s", err)
} }
if !reflect.DeepEqual(*decisions, *expected) { if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected) t.Fatalf("returned %+v, want %+v", resp, expected)
} }
@ -184,8 +185,9 @@ func TestDecisionsStream(t *testing.T) {
if resp.Response.StatusCode != http.StatusOK { if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
} }
assert.Equal(t, 0, len(decisions.New))
assert.Equal(t, 0, len(decisions.Deleted)) assert.Empty(t, decisions.New)
assert.Empty(t, decisions.Deleted)
//delete stream //delete stream
resp, err = newcli.Decisions.StopStream(context.Background()) resp, err = newcli.Decisions.StopStream(context.Background())
@ -203,8 +205,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
defer teardown() defer teardown()
mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu")
testMethod(t, r, http.MethodGet) testMethod(t, r, http.MethodGet)
if r.Method == http.MethodGet { if r.Method == http.MethodGet {
if r.URL.RawQuery == "startup=true" { if r.URL.RawQuery == "startup=true" {
@ -275,6 +276,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("new api client: %s", err) t.Fatalf("new api client: %s", err)
} }
if !reflect.DeepEqual(*decisions, *expected) { if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected) t.Fatalf("returned %+v, want %+v", resp, expected)
} }
@ -287,8 +289,7 @@ func TestDecisionsStreamV3(t *testing.T) {
defer teardown() defer teardown()
mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/decisions/stream", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu")
testMethod(t, r, http.MethodGet) testMethod(t, r, http.MethodGet)
if r.Method == http.MethodGet { if r.Method == http.MethodGet {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@ -368,6 +369,7 @@ func TestDecisionsStreamV3(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("new api client: %s", err) t.Fatalf("new api client: %s", err)
} }
if !reflect.DeepEqual(*decisions, *expected) { if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected) t.Fatalf("returned %+v, want %+v", resp, expected)
} }
@ -451,6 +453,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("new api client: %s", err) t.Fatalf("new api client: %s", err)
} }
if !reflect.DeepEqual(decisions, expected) { if !reflect.DeepEqual(decisions, expected) {
t.Fatalf("returned %+v, want %+v", decisions, expected) t.Fatalf("returned %+v, want %+v", decisions, expected)
} }
@ -484,7 +487,7 @@ func TestDeleteDecisions(t *testing.T) {
}) })
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "DELETE") testMethod(t, r, "DELETE")
assert.Equal(t, r.URL.RawQuery, "ip=1.2.3.4") assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"nbDeleted":"1"}`)) w.Write([]byte(`{"nbDeleted":"1"}`))
//w.Write([]byte(`{"message":"0 deleted alerts"}`)) //w.Write([]byte(`{"message":"0 deleted alerts"}`))
@ -512,6 +515,7 @@ func TestDeleteDecisions(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected err : %s", err) t.Fatalf("unexpected err : %s", err)
} }
assert.Equal(t, "1", deleted.NbDeleted) assert.Equal(t, "1", deleted.NbDeleted)
defer teardown() defer teardown()
@ -519,6 +523,7 @@ func TestDeleteDecisions(t *testing.T) {
func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
baseURLString := "http://localhost:8080/v1/decisions/stream" baseURLString := "http://localhost:8080/v1/decisions/stream"
type fields struct { type fields struct {
Startup bool Startup bool
Scopes string Scopes string
@ -553,6 +558,7 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
want: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true", want: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {

View file

@ -15,7 +15,9 @@ type DecisionDeleteService service
// DecisionDeleteService purposely reuses AddSignalsRequestItemDecisions model // DecisionDeleteService purposely reuses AddSignalsRequestItemDecisions model
func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) { func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) {
var response interface{} var response interface{}
u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix) u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix)
req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions) req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("while building request: %w", err) return nil, nil, fmt.Errorf("while building request: %w", err)
@ -25,10 +27,12 @@ func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *model
if err != nil { if err != nil {
return nil, resp, fmt.Errorf("while performing request: %w", err) return nil, resp, fmt.Errorf("while performing request: %w", err)
} }
if resp.Response.StatusCode != http.StatusOK { if resp.Response.StatusCode != http.StatusOK {
log.Warnf("Decisions delete response : http %s", resp.Response.Status) log.Warnf("Decisions delete response : http %s", resp.Response.Status)
} else { } else {
log.Debugf("Decisions delete response : http %s", resp.Response.Status) log.Debugf("Decisions delete response : http %s", resp.Response.Status)
} }
return &response, resp, nil return &response, resp, nil
} }

View file

@ -15,7 +15,6 @@ import (
type HeartBeatService service type HeartBeatService service
func (h *HeartBeatService) Ping(ctx context.Context) (bool, *Response, error) { func (h *HeartBeatService) Ping(ctx context.Context) (bool, *Response, error) {
u := fmt.Sprintf("%s/heartbeat", h.client.URLPrefix) u := fmt.Sprintf("%s/heartbeat", h.client.URLPrefix)
req, err := h.client.NewRequest(http.MethodGet, u, nil) req, err := h.client.NewRequest(http.MethodGet, u, nil)

View file

@ -14,6 +14,7 @@ func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (inte
var response interface{} var response interface{}
u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix) u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &metrics) req, err := s.client.NewRequest(http.MethodPost, u, &metrics)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -23,5 +24,6 @@ func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (inte
if err != nil { if err != nil {
return nil, resp, err return nil, resp, err
} }
return &response, resp, nil return &response, resp, nil
} }

View file

@ -16,6 +16,7 @@ func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsReque
var response interface{} var response interface{}
u := fmt.Sprintf("%s/signals", s.client.URLPrefix) u := fmt.Sprintf("%s/signals", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &signals) req, err := s.client.NewRequest(http.MethodPost, u, &signals)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("while building request: %w", err) return nil, nil, fmt.Errorf("while building request: %w", err)
@ -25,10 +26,12 @@ func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsReque
if err != nil { if err != nil {
return nil, resp, fmt.Errorf("while performing request: %w", err) return nil, resp, fmt.Errorf("while performing request: %w", err)
} }
if resp.Response.StatusCode != http.StatusOK { if resp.Response.StatusCode != http.StatusOK {
log.Warnf("Signal push response : http %s", resp.Response.Status) log.Warnf("Signal push response : http %s", resp.Response.Status)
} else { } else {
log.Debugf("Signal push response : http %s", resp.Response.Status) log.Debugf("Signal push response : http %s", resp.Response.Status)
} }
return &response, resp, nil return &response, resp, nil
} }

View file

@ -9,13 +9,13 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/csplugin"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
) )
type LAPI struct { type LAPI struct {
@ -57,6 +57,7 @@ func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, aut
if err != nil { if err != nil {
l.t.Fatal(err) l.t.Fatal(err)
} }
if authType == "apikey" { if authType == "apikey" {
req.Header.Add("X-Api-Key", l.bouncerKey) req.Header.Add("X-Api-Key", l.bouncerKey)
} else if authType == "password" { } else if authType == "password" {
@ -64,7 +65,9 @@ func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, aut
} else { } else {
l.t.Fatal("auth type not supported") l.t.Fatal("auth type not supported")
} }
l.router.ServeHTTP(w, req) l.router.ServeHTTP(w, req)
return w return w
} }
@ -78,6 +81,7 @@ func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csc
if err != nil { if err != nil {
return nil, models.WatcherAuthResponse{}, config, err return nil, models.WatcherAuthResponse{}, config, err
} }
return router, loginResp, config, nil return router, loginResp, config, nil
} }
@ -150,7 +154,6 @@ func TestCreateAlert(t *testing.T) {
} }
func TestCreateAlertChannels(t *testing.T) { func TestCreateAlertChannels(t *testing.T) {
apiServer, config, err := NewAPIServer(t) apiServer, config, err := NewAPIServer(t)
if err != nil { if err != nil {
log.Fatalln(err) log.Fatalln(err)
@ -164,18 +167,22 @@ func TestCreateAlertChannels(t *testing.T) {
} }
lapi := LAPI{router: apiServer.router, loginResp: loginResp} lapi := LAPI{router: apiServer.router, loginResp: loginResp}
var pd csplugin.ProfileAlert var (
var wg sync.WaitGroup pd csplugin.ProfileAlert
wg sync.WaitGroup
)
wg.Add(1) wg.Add(1)
go func() { go func() {
pd = <-apiServer.controller.PluginChannel pd = <-apiServer.controller.PluginChannel
wg.Done() wg.Done()
}() }()
go lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json") go lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json")
wg.Wait() wg.Wait()
assert.Equal(t, len(pd.Alert.Decisions), 1) assert.Len(t, pd.Alert.Decisions, 1)
apiServer.Close() apiServer.Close()
} }
@ -345,7 +352,6 @@ func TestAlertListFilters(t *testing.T) {
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String())
} }
func TestAlertBulkInsert(t *testing.T) { func TestAlertBulkInsert(t *testing.T) {
@ -393,7 +399,6 @@ func TestCreateAlertErrors(t *testing.T) {
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s")) req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", lapi.loginResp.Token+"s"))
lapi.router.ServeHTTP(w, req) lapi.router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code) assert.Equal(t, 401, w.Code)
} }
func TestDeleteAlert(t *testing.T) { func TestDeleteAlert(t *testing.T) {
@ -506,5 +511,4 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
lapi.InsertAlertFromFile("./tests/alert_sample.json") lapi.InsertAlertFromFile("./tests/alert_sample.json")
assertAlertDeletedFromIP("127.0.0.1") assertAlertDeletedFromIP("127.0.0.1")
} }

View file

@ -48,5 +48,4 @@ func TestAPIKey(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String()) assert.Equal(t, "null", w.Body.String())
} }

View file

@ -75,12 +75,14 @@ func randomDuration(d time.Duration, delta time.Duration) time.Duration {
if ret <= 0 { if ret <= 0 {
return 1 return 1
} }
return ret return ret
} }
func (a *apic) FetchScenariosListFromDB() ([]string, error) { func (a *apic) FetchScenariosListFromDB() ([]string, error) {
scenarios := make([]string, 0) scenarios := make([]string, 0)
machines, err := a.dbClient.ListMachines() machines, err := a.dbClient.ListMachines()
if err != nil { if err != nil {
return nil, fmt.Errorf("while listing machines: %w", err) return nil, fmt.Errorf("while listing machines: %w", err)
} }
@ -88,18 +90,22 @@ func (a *apic) FetchScenariosListFromDB() ([]string, error) {
for _, v := range machines { for _, v := range machines {
machineScenarios := strings.Split(v.Scenarios, ",") machineScenarios := strings.Split(v.Scenarios, ",")
log.Debugf("%d scenarios for machine %d", len(machineScenarios), v.ID) log.Debugf("%d scenarios for machine %d", len(machineScenarios), v.ID)
for _, sv := range machineScenarios { for _, sv := range machineScenarios {
if !slices.Contains(scenarios, sv) && sv != "" { if !slices.Contains(scenarios, sv) && sv != "" {
scenarios = append(scenarios, sv) scenarios = append(scenarios, sv)
} }
} }
} }
log.Debugf("Returning list of scenarios : %+v", scenarios) log.Debugf("Returning list of scenarios : %+v", scenarios)
return scenarios, nil return scenarios, nil
} }
func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequestItemDecisions { func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequestItemDecisions {
apiDecisions := models.AddSignalsRequestItemDecisions{} apiDecisions := models.AddSignalsRequestItemDecisions{}
for _, decision := range decisions { for _, decision := range decisions {
x := &models.AddSignalsRequestItemDecisionsItem{ x := &models.AddSignalsRequestItemDecisionsItem{
Duration: ptr.Of(*decision.Duration), Duration: ptr.Of(*decision.Duration),
@ -114,11 +120,14 @@ func decisionsToApiDecisions(decisions []*models.Decision) models.AddSignalsRequ
UUID: decision.UUID, UUID: decision.UUID,
} }
*x.ID = decision.ID *x.ID = decision.ID
if decision.Simulated != nil { if decision.Simulated != nil {
x.Simulated = *decision.Simulated x.Simulated = *decision.Simulated
} }
apiDecisions = append(apiDecisions, x) apiDecisions = append(apiDecisions, x)
} }
return apiDecisions return apiDecisions
} }
@ -149,6 +158,7 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool)
} }
if shareContext { if shareContext {
signal.Context = make([]*models.AddSignalsRequestItemContextItems0, 0) signal.Context = make([]*models.AddSignalsRequestItemContextItems0, 0)
for _, meta := range alert.Meta { for _, meta := range alert.Meta {
contextItem := models.AddSignalsRequestItemContextItems0{ contextItem := models.AddSignalsRequestItemContextItems0{
Key: meta.Key, Key: meta.Key,
@ -157,13 +167,14 @@ func alertToSignal(alert *models.Alert, scenarioTrust string, shareContext bool)
signal.Context = append(signal.Context, &contextItem) signal.Context = append(signal.Context, &contextItem)
} }
} }
return signal return signal
} }
func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) { func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, apicWhitelist *csconfig.CapiWhitelist) (*apic, error) {
var err error var err error
ret := &apic{
ret := &apic{
AlertsAddChan: make(chan []*models.Alert), AlertsAddChan: make(chan []*models.Alert),
dbClient: dbClient, dbClient: dbClient,
mu: sync.Mutex{}, mu: sync.Mutex{},
@ -186,9 +197,11 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con
password := strfmt.Password(config.Credentials.Password) password := strfmt.Password(config.Credentials.Password)
apiURL, err := url.Parse(config.Credentials.URL) apiURL, err := url.Parse(config.Credentials.URL)
if err != nil { if err != nil {
return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.URL, err) return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.URL, err)
} }
papiURL, err := url.Parse(config.Credentials.PapiURL) papiURL, err := url.Parse(config.Credentials.PapiURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err) return nil, fmt.Errorf("while parsing '%s': %w", config.Credentials.PapiURL, err)
@ -198,6 +211,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con
if err != nil { if err != nil {
return nil, fmt.Errorf("while fetching scenarios from db: %w", err) return nil, fmt.Errorf("while fetching scenarios from db: %w", err)
} }
ret.apiClient, err = apiclient.NewClient(&apiclient.Config{ ret.apiClient, err = apiclient.NewClient(&apiclient.Config{
MachineID: config.Credentials.Login, MachineID: config.Credentials.Login,
Password: password, Password: password,
@ -228,7 +242,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con
return ret, fmt.Errorf("authenticate watcher (%s): %w", config.Credentials.Login, err) return ret, fmt.Errorf("authenticate watcher (%s): %w", config.Credentials.Login, err)
} }
if err := ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { if err = ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil {
return ret, fmt.Errorf("unable to parse jwt expiration: %w", err) return ret, fmt.Errorf("unable to parse jwt expiration: %w", err)
} }
@ -242,6 +256,7 @@ func (a *apic) Push() error {
defer trace.CatchPanic("lapi/pushToAPIC") defer trace.CatchPanic("lapi/pushToAPIC")
var cache models.AddSignalsRequest var cache models.AddSignalsRequest
ticker := time.NewTicker(a.pushIntervalFirst) ticker := time.NewTicker(a.pushIntervalFirst)
log.Infof("Start push to CrowdSec Central API (interval: %s once, then %s)", a.pushIntervalFirst.Round(time.Second), a.pushInterval) log.Infof("Start push to CrowdSec Central API (interval: %s once, then %s)", a.pushIntervalFirst.Round(time.Second), a.pushInterval)
@ -252,28 +267,35 @@ func (a *apic) Push() error {
a.pullTomb.Kill(nil) a.pullTomb.Kill(nil)
a.metricsTomb.Kill(nil) a.metricsTomb.Kill(nil)
log.Infof("push tomb is dying, sending cache (%d elements) before exiting", len(cache)) log.Infof("push tomb is dying, sending cache (%d elements) before exiting", len(cache))
if len(cache) == 0 { if len(cache) == 0 {
return nil return nil
} }
go a.Send(&cache) go a.Send(&cache)
return nil return nil
case <-ticker.C: case <-ticker.C:
ticker.Reset(a.pushInterval) ticker.Reset(a.pushInterval)
if len(cache) > 0 { if len(cache) > 0 {
a.mu.Lock() a.mu.Lock()
cacheCopy := cache cacheCopy := cache
cache = make(models.AddSignalsRequest, 0) cache = make(models.AddSignalsRequest, 0)
a.mu.Unlock() a.mu.Unlock()
log.Infof("Signal push: %d signals to push", len(cacheCopy)) log.Infof("Signal push: %d signals to push", len(cacheCopy))
go a.Send(&cacheCopy) go a.Send(&cacheCopy)
} }
case alerts := <-a.AlertsAddChan: case alerts := <-a.AlertsAddChan:
var signals []*models.AddSignalsRequestItem var signals []*models.AddSignalsRequestItem
for _, alert := range alerts { for _, alert := range alerts {
if ok := shouldShareAlert(alert, a.consoleConfig); ok { if ok := shouldShareAlert(alert, a.consoleConfig); ok {
signals = append(signals, alertToSignal(alert, getScenarioTrustOfAlert(alert), *a.consoleConfig.ShareContext)) signals = append(signals, alertToSignal(alert, getScenarioTrustOfAlert(alert), *a.consoleConfig.ShareContext))
} }
} }
a.mu.Lock() a.mu.Lock()
cache = append(cache, signals...) cache = append(cache, signals...)
a.mu.Unlock() a.mu.Unlock()
@ -288,11 +310,13 @@ func getScenarioTrustOfAlert(alert *models.Alert) string {
} else if alert.ScenarioVersion == nil || *alert.ScenarioVersion == "" || *alert.ScenarioVersion == "?" { } else if alert.ScenarioVersion == nil || *alert.ScenarioVersion == "" || *alert.ScenarioVersion == "?" {
scenarioTrust = "tainted" scenarioTrust = "tainted"
} }
if len(alert.Decisions) > 0 { if len(alert.Decisions) > 0 {
if *alert.Decisions[0].Origin == types.CscliOrigin { if *alert.Decisions[0].Origin == types.CscliOrigin {
scenarioTrust = "manual" scenarioTrust = "manual"
} }
} }
return scenarioTrust return scenarioTrust
} }
@ -301,6 +325,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig
log.Debugf("simulation enabled for alert (id:%d), will not be sent to CAPI", alert.ID) log.Debugf("simulation enabled for alert (id:%d), will not be sent to CAPI", alert.ID)
return false return false
} }
switch scenarioTrust := getScenarioTrustOfAlert(alert); scenarioTrust { switch scenarioTrust := getScenarioTrustOfAlert(alert); scenarioTrust {
case "manual": case "manual":
if !*consoleConfig.ShareManualDecisions { if !*consoleConfig.ShareManualDecisions {
@ -318,6 +343,7 @@ func shouldShareAlert(alert *models.Alert, consoleConfig *csconfig.ConsoleConfig
return false return false
} }
} }
return true return true
} }
@ -333,34 +359,44 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) {
I don't know enough about gin to tell how much of an issue it can be. I don't know enough about gin to tell how much of an issue it can be.
*/ */
var cache []*models.AddSignalsRequestItem = *cacheOrig var (
var send models.AddSignalsRequest cache []*models.AddSignalsRequestItem = *cacheOrig
send models.AddSignalsRequest
)
bulkSize := 50 bulkSize := 50
pageStart := 0 pageStart := 0
pageEnd := bulkSize pageEnd := bulkSize
for { for {
if pageEnd >= len(cache) { if pageEnd >= len(cache) {
send = cache[pageStart:] send = cache[pageStart:]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
_, _, err := a.apiClient.Signal.Add(ctx, &send) _, _, err := a.apiClient.Signal.Add(ctx, &send)
if err != nil { if err != nil {
log.Errorf("sending signal to central API: %s", err) log.Errorf("sending signal to central API: %s", err)
return return
} }
break break
} }
send = cache[pageStart:pageEnd] send = cache[pageStart:pageEnd]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
_, _, err := a.apiClient.Signal.Add(ctx, &send) _, _, err := a.apiClient.Signal.Add(ctx, &send)
if err != nil { if err != nil {
//we log it here as well, because the return value of func might be discarded //we log it here as well, because the return value of func might be discarded
log.Errorf("sending signal to central API: %s", err) log.Errorf("sending signal to central API: %s", err)
} }
pageStart += bulkSize pageStart += bulkSize
pageEnd += bulkSize pageEnd += bulkSize
} }
@ -372,18 +408,22 @@ func (a *apic) CAPIPullIsOld() (bool, error) {
alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID))) alerts = alerts.Where(alert.HasDecisionsWith(decision.OriginEQ(database.CapiMachineID)))
alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert alerts = alerts.Where(alert.CreatedAtGTE(time.Now().UTC().Add(-time.Duration(1*time.Hour + 30*time.Minute)))) //nolint:unconvert
count, err := alerts.Count(a.dbClient.CTX) count, err := alerts.Count(a.dbClient.CTX)
if err != nil { if err != nil {
return false, fmt.Errorf("while looking for CAPI alert: %w", err) return false, fmt.Errorf("while looking for CAPI alert: %w", err)
} }
if count > 0 { if count > 0 {
log.Printf("last CAPI pull is newer than 1h30, skip.") log.Printf("last CAPI pull is newer than 1h30, skip.")
return false, nil return false, nil
} }
return true, nil return true, nil
} }
func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delete_counters map[string]map[string]int) (int, error) { func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, deleteCounters map[string]map[string]int) (int, error) {
nbDeleted := 0 nbDeleted := 0
for _, decision := range deletedDecisions { for _, decision := range deletedDecisions {
filter := map[string][]string{ filter := map[string][]string{
"value": {*decision.Value}, "value": {*decision.Value},
@ -398,20 +438,25 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
if err != nil { if err != nil {
return 0, fmt.Errorf("deleting decisions error: %w", err) return 0, fmt.Errorf("deleting decisions error: %w", err)
} }
dbCliDel, err := strconv.Atoi(dbCliRet) dbCliDel, err := strconv.Atoi(dbCliRet)
if err != nil { if err != nil {
return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err) return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err)
} }
updateCounterForDecision(delete_counters, decision.Origin, decision.Scenario, dbCliDel)
updateCounterForDecision(deleteCounters, decision.Origin, decision.Scenario, dbCliDel)
nbDeleted += dbCliDel nbDeleted += dbCliDel
} }
return nbDeleted, nil return nbDeleted, nil
} }
func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, delete_counters map[string]map[string]int) (int, error) { func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, deleteCounters map[string]map[string]int) (int, error) {
var nbDeleted int var nbDeleted int
for _, decisions := range deletedDecisions { for _, decisions := range deletedDecisions {
scope := decisions.Scope scope := decisions.Scope
for _, decision := range decisions.Decisions { for _, decision := range decisions.Decisions {
filter := map[string][]string{ filter := map[string][]string{
"value": {decision}, "value": {decision},
@ -425,26 +470,32 @@ func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisi
if err != nil { if err != nil {
return 0, fmt.Errorf("deleting decisions error: %w", err) return 0, fmt.Errorf("deleting decisions error: %w", err)
} }
dbCliDel, err := strconv.Atoi(dbCliRet) dbCliDel, err := strconv.Atoi(dbCliRet)
if err != nil { if err != nil {
return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err) return 0, fmt.Errorf("converting db ret %d: %w", dbCliDel, err)
} }
updateCounterForDecision(delete_counters, ptr.Of(types.CAPIOrigin), nil, dbCliDel)
updateCounterForDecision(deleteCounters, ptr.Of(types.CAPIOrigin), nil, dbCliDel)
nbDeleted += dbCliDel nbDeleted += dbCliDel
} }
} }
return nbDeleted, nil return nbDeleted, nil
} }
func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert { func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert {
newAlerts := make([]*models.Alert, 0) newAlerts := make([]*models.Alert, 0)
for _, decision := range decisions { for _, decision := range decisions {
found := false found := false
for _, sub := range newAlerts { for _, sub := range newAlerts {
if sub.Source.Scope == nil { if sub.Source.Scope == nil {
log.Warningf("nil scope in %+v", sub) log.Warningf("nil scope in %+v", sub)
continue continue
} }
if *decision.Origin == types.CAPIOrigin { if *decision.Origin == types.CAPIOrigin {
if *sub.Source.Scope == types.CAPIOrigin { if *sub.Source.Scope == types.CAPIOrigin {
found = true found = true
@ -464,11 +515,13 @@ func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert {
log.Warningf("unknown origin %s : %+v", *decision.Origin, decision) log.Warningf("unknown origin %s : %+v", *decision.Origin, decision)
} }
} }
if !found { if !found {
log.Debugf("Create entry for origin:%s scenario:%s", *decision.Origin, *decision.Scenario) log.Debugf("Create entry for origin:%s scenario:%s", *decision.Origin, *decision.Scenario)
newAlerts = append(newAlerts, createAlertForDecision(decision)) newAlerts = append(newAlerts, createAlertForDecision(decision))
} }
} }
return newAlerts return newAlerts
} }
@ -489,6 +542,7 @@ func createAlertForDecision(decision *models.Decision) *models.Alert {
// XXX: this or nil? // XXX: this or nil?
scenario = "" scenario = ""
scope = "" scope = ""
log.Warningf("unknown origin %s", *decision.Origin) log.Warningf("unknown origin %s", *decision.Origin)
} }
@ -512,10 +566,10 @@ func createAlertForDecision(decision *models.Decision) *models.Alert {
} }
// This function takes in list of parent alerts and decisions and then pairs them up. // This function takes in list of parent alerts and decisions and then pairs them up.
func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, add_counters map[string]map[string]int) []*models.Alert { func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decision, addCounters map[string]map[string]int) []*models.Alert {
for _, decision := range decisions { for _, decision := range decisions {
//count and create separate alerts for each list //count and create separate alerts for each list
updateCounterForDecision(add_counters, decision.Origin, decision.Scenario, 1) updateCounterForDecision(addCounters, decision.Origin, decision.Scenario, 1)
/*CAPI might send lower case scopes, unify it.*/ /*CAPI might send lower case scopes, unify it.*/
switch strings.ToLower(*decision.Scope) { switch strings.ToLower(*decision.Scope) {
@ -524,6 +578,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
case "range": case "range":
*decision.Scope = types.Range *decision.Scope = types.Range
} }
found := false found := false
//add the individual decisions to the right list //add the individual decisions to the right list
for idx, alert := range alerts { for idx, alert := range alerts {
@ -531,6 +586,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
if *alert.Source.Scope == types.CAPIOrigin { if *alert.Source.Scope == types.CAPIOrigin {
alerts[idx].Decisions = append(alerts[idx].Decisions, decision) alerts[idx].Decisions = append(alerts[idx].Decisions, decision)
found = true found = true
break break
} }
} else if *decision.Origin == types.ListOrigin { } else if *decision.Origin == types.ListOrigin {
@ -543,10 +599,12 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
log.Warningf("unknown origin %s", *decision.Origin) log.Warningf("unknown origin %s", *decision.Origin)
} }
} }
if !found { if !found {
log.Warningf("Orphaned decision for %s - %s", *decision.Origin, *decision.Scenario) log.Warningf("Orphaned decision for %s - %s", *decision.Origin, *decision.Scenario)
} }
} }
return alerts return alerts
} }
@ -581,18 +639,20 @@ func (a *apic) PullTop(forcePull bool) error {
if err != nil { if err != nil {
return fmt.Errorf("get stream: %w", err) return fmt.Errorf("get stream: %w", err)
} }
a.startup = false a.startup = false
/*to count additions/deletions across lists*/ /*to count additions/deletions across lists*/
log.Debugf("Received %d new decisions", len(data.New)) log.Debugf("Received %d new decisions", len(data.New))
log.Debugf("Received %d deleted decisions", len(data.Deleted)) log.Debugf("Received %d deleted decisions", len(data.Deleted))
if data.Links != nil { if data.Links != nil {
log.Debugf("Received %d blocklists links", len(data.Links.Blocklists)) log.Debugf("Received %d blocklists links", len(data.Links.Blocklists))
} }
add_counters, delete_counters := makeAddAndDeleteCounters() addCounters, deleteCounters := makeAddAndDeleteCounters()
// process deleted decisions // process deleted decisions
if nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, delete_counters); err != nil { if nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters); err != nil {
return err return err
} else { } else {
log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted) log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted)
@ -610,28 +670,30 @@ func (a *apic) PullTop(forcePull bool) error {
alert := createAlertForDecision(decisions[0]) alert := createAlertForDecision(decisions[0])
alertsFromCapi := []*models.Alert{alert} alertsFromCapi := []*models.Alert{alert}
alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, add_counters) alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)
err = a.SaveAlerts(alertsFromCapi, add_counters, delete_counters) err = a.SaveAlerts(alertsFromCapi, addCounters, deleteCounters)
if err != nil { if err != nil {
return fmt.Errorf("while saving alerts: %w", err) return fmt.Errorf("while saving alerts: %w", err)
} }
// update blocklists // update blocklists
if err := a.UpdateBlocklists(data.Links, add_counters, forcePull); err != nil { if err := a.UpdateBlocklists(data.Links, addCounters, forcePull); err != nil {
return fmt.Errorf("while updating blocklists: %w", err) return fmt.Errorf("while updating blocklists: %w", err)
} }
return nil return nil
} }
// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert // we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error { func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error {
add_counters, _ := makeAddAndDeleteCounters() addCounters, _ := makeAddAndDeleteCounters()
if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{ if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{
Blocklists: []*modelscapi.BlocklistLink{blocklist}, Blocklists: []*modelscapi.BlocklistLink{blocklist},
}, add_counters, forcePull); err != nil { }, addCounters, forcePull); err != nil {
return fmt.Errorf("while pulling blocklist: %w", err) return fmt.Errorf("while pulling blocklist: %w", err)
} }
return nil return nil
} }
@ -641,17 +703,20 @@ func (a *apic) whitelistedBy(decision *models.Decision) string {
if decision.Value == nil { if decision.Value == nil {
return "" return ""
} }
ipval := net.ParseIP(*decision.Value) ipval := net.ParseIP(*decision.Value)
for _, cidr := range a.whitelists.Cidrs { for _, cidr := range a.whitelists.Cidrs {
if cidr.Contains(ipval) { if cidr.Contains(ipval) {
return cidr.String() return cidr.String()
} }
} }
for _, ip := range a.whitelists.Ips { for _, ip := range a.whitelists.Ips {
if ip != nil && ip.Equal(ipval) { if ip != nil && ip.Equal(ipval) {
return ip.String() return ip.String()
} }
} }
return "" return ""
} }
@ -661,12 +726,14 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis
} }
//deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place //deal with CAPI whitelists for fire. We want to avoid having a second list, so we shrink in place
outIdx := 0 outIdx := 0
for _, decision := range decisions { for _, decision := range decisions {
whitelister := a.whitelistedBy(decision) whitelister := a.whitelistedBy(decision)
if whitelister != "" { if whitelister != "" {
log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, whitelister) log.Infof("%s from %s is whitelisted by %s", *decision.Value, *decision.Scenario, whitelister)
continue continue
} }
decisions[outIdx] = decision decisions[outIdx] = decision
outIdx++ outIdx++
} }
@ -674,17 +741,20 @@ func (a *apic) ApplyApicWhitelists(decisions []*models.Decision) []*models.Decis
return decisions[:outIdx] return decisions[:outIdx]
} }
func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, add_counters map[string]map[string]int, delete_counters map[string]map[string]int) error { func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) error {
for _, alert := range alertsFromCapi { for _, alert := range alertsFromCapi {
setAlertScenario(alert, add_counters, delete_counters) setAlertScenario(alert, addCounters, deleteCounters)
log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions)) log.Debugf("%s has %d decisions", *alert.Source.Scope, len(alert.Decisions))
if a.dbClient.Type == "sqlite" && (a.dbClient.WalMode == nil || !*a.dbClient.WalMode) { if a.dbClient.Type == "sqlite" && (a.dbClient.WalMode == nil || !*a.dbClient.WalMode) {
log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist") log.Warningf("sqlite is not using WAL mode, LAPI might become unresponsive when inserting the community blocklist")
} }
alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert) alertID, inserted, deleted, err := a.dbClient.UpdateCommunityBlocklist(alert)
if err != nil { if err != nil {
return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err) return fmt.Errorf("while saving alert from %s: %w", *alert.Source.Scope, err)
} }
log.Printf("%s : added %d entries, deleted %d entries (alert:%d)", *alert.Source.Scope, inserted, deleted, alertID) log.Printf("%s : added %d entries, deleted %d entries (alert:%d)", *alert.Source.Scope, inserted, deleted, alertID)
} }
@ -697,71 +767,91 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name))) alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name)))
alertQuery.Order(ent.Desc(alert.FieldCreatedAt)) alertQuery.Order(ent.Desc(alert.FieldCreatedAt))
alertInstance, err := alertQuery.First(context.Background()) alertInstance, err := alertQuery.First(context.Background())
if err != nil { if err != nil {
if ent.IsNotFound(err) { if ent.IsNotFound(err) {
log.Debugf("no alert found for %s, force refresh", *blocklist.Name) log.Debugf("no alert found for %s, force refresh", *blocklist.Name)
return true, nil return true, nil
} }
return false, fmt.Errorf("while getting alert: %w", err) return false, fmt.Errorf("while getting alert: %w", err)
} }
decisionQuery := a.dbClient.Ent.Decision.Query() decisionQuery := a.dbClient.Ent.Decision.Query()
decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID))) decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID)))
firstDecision, err := decisionQuery.First(context.Background()) firstDecision, err := decisionQuery.First(context.Background())
if err != nil { if err != nil {
if ent.IsNotFound(err) { if ent.IsNotFound(err) {
log.Debugf("no decision found for %s, force refresh", *blocklist.Name) log.Debugf("no decision found for %s, force refresh", *blocklist.Name)
return true, nil return true, nil
} }
return false, fmt.Errorf("while getting decision: %w", err) return false, fmt.Errorf("while getting decision: %w", err)
} }
if firstDecision == nil || firstDecision.Until == nil || firstDecision.Until.Sub(time.Now().UTC()) < (a.pullInterval+15*time.Minute) { if firstDecision == nil || firstDecision.Until == nil || firstDecision.Until.Sub(time.Now().UTC()) < (a.pullInterval+15*time.Minute) {
log.Debugf("at least one decision found for %s, expire soon, force refresh", *blocklist.Name) log.Debugf("at least one decision found for %s, expire soon, force refresh", *blocklist.Name)
return true, nil return true, nil
} }
return false, nil return false, nil
} }
func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, add_counters map[string]map[string]int, forcePull bool) error { func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, addCounters map[string]map[string]int, forcePull bool) error {
if blocklist.Scope == nil { if blocklist.Scope == nil {
log.Warningf("blocklist has no scope") log.Warningf("blocklist has no scope")
return nil return nil
} }
if blocklist.Duration == nil { if blocklist.Duration == nil {
log.Warningf("blocklist has no duration") log.Warningf("blocklist has no duration")
return nil return nil
} }
if !forcePull { if !forcePull {
_forcePull, err := a.ShouldForcePullBlocklist(blocklist) _forcePull, err := a.ShouldForcePullBlocklist(blocklist)
if err != nil { if err != nil {
return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err) return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err)
} }
forcePull = _forcePull forcePull = _forcePull
} }
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name) blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
var lastPullTimestamp *string
var err error var (
lastPullTimestamp *string
err error
)
if !forcePull { if !forcePull {
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
if err != nil { if err != nil {
return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) return fmt.Errorf("while getting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
} }
} }
decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp) decisions, hasChanged, err := client.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
if err != nil { if err != nil {
return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err) return fmt.Errorf("while getting decisions from blocklist %s: %w", *blocklist.Name, err)
} }
if !hasChanged { if !hasChanged {
if lastPullTimestamp == nil { if lastPullTimestamp == nil {
log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name) log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
} else { } else {
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp) log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
} }
return nil return nil
} }
err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
if err != nil { if err != nil {
return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err) return fmt.Errorf("while setting last pull timestamp for blocklist %s: %w", *blocklist.Name, err)
} }
if len(decisions) == 0 { if len(decisions) == 0 {
log.Infof("blocklist %s has no decisions", *blocklist.Name) log.Infof("blocklist %s has no decisions", *blocklist.Name)
return nil return nil
@ -770,19 +860,21 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
decisions = a.ApplyApicWhitelists(decisions) decisions = a.ApplyApicWhitelists(decisions)
alert := createAlertForDecision(decisions[0]) alert := createAlertForDecision(decisions[0])
alertsFromCapi := []*models.Alert{alert} alertsFromCapi := []*models.Alert{alert}
alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, add_counters) alertsFromCapi = fillAlertsWithDecisions(alertsFromCapi, decisions, addCounters)
err = a.SaveAlerts(alertsFromCapi, add_counters, nil) err = a.SaveAlerts(alertsFromCapi, addCounters, nil)
if err != nil { if err != nil {
return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err) return fmt.Errorf("while saving alert from blocklist %s: %w", *blocklist.Name, err)
} }
return nil return nil
} }
func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int, forcePull bool) error { func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, addCounters map[string]map[string]int, forcePull bool) error {
if links == nil { if links == nil {
return nil return nil
} }
if links.Blocklists == nil { if links.Blocklists == nil {
return nil return nil
} }
@ -792,21 +884,23 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
if err != nil { if err != nil {
return fmt.Errorf("while creating default client: %w", err) return fmt.Errorf("while creating default client: %w", err)
} }
for _, blocklist := range links.Blocklists { for _, blocklist := range links.Blocklists {
if err := a.updateBlocklist(defaultClient, blocklist, add_counters, forcePull); err != nil { if err := a.updateBlocklist(defaultClient, blocklist, addCounters, forcePull); err != nil {
return err return err
} }
} }
return nil return nil
} }
func setAlertScenario(alert *models.Alert, add_counters map[string]map[string]int, delete_counters map[string]map[string]int) { func setAlertScenario(alert *models.Alert, addCounters map[string]map[string]int, deleteCounters map[string]map[string]int) {
if *alert.Source.Scope == types.CAPIOrigin { if *alert.Source.Scope == types.CAPIOrigin {
*alert.Source.Scope = types.CommunityBlocklistPullSourceScope *alert.Source.Scope = types.CommunityBlocklistPullSourceScope
alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", add_counters[types.CAPIOrigin]["all"], delete_counters[types.CAPIOrigin]["all"])) alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.CAPIOrigin]["all"], deleteCounters[types.CAPIOrigin]["all"]))
} else if *alert.Source.Scope == types.ListOrigin { } else if *alert.Source.Scope == types.ListOrigin {
*alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario) *alert.Source.Scope = fmt.Sprintf("%s:%s", types.ListOrigin, *alert.Scenario)
alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", add_counters[types.ListOrigin][*alert.Scenario], delete_counters[types.ListOrigin][*alert.Scenario])) alert.Scenario = ptr.Of(fmt.Sprintf("update : +%d/-%d IPs", addCounters[types.ListOrigin][*alert.Scenario], deleteCounters[types.ListOrigin][*alert.Scenario]))
} }
} }
@ -814,20 +908,26 @@ func (a *apic) Pull() error {
defer trace.CatchPanic("lapi/pullFromAPIC") defer trace.CatchPanic("lapi/pullFromAPIC")
toldOnce := false toldOnce := false
for { for {
scenario, err := a.FetchScenariosListFromDB() scenario, err := a.FetchScenariosListFromDB()
if err != nil { if err != nil {
log.Errorf("unable to fetch scenarios from db: %s", err) log.Errorf("unable to fetch scenarios from db: %s", err)
} }
if len(scenario) > 0 { if len(scenario) > 0 {
break break
} }
if !toldOnce { if !toldOnce {
log.Warning("scenario list is empty, will not pull yet") log.Warning("scenario list is empty, will not pull yet")
toldOnce = true toldOnce = true
} }
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
if err := a.PullTop(false); err != nil { if err := a.PullTop(false); err != nil {
log.Errorf("capi pull top: %s", err) log.Errorf("capi pull top: %s", err)
} }
@ -839,6 +939,7 @@ func (a *apic) Pull() error {
select { select {
case <-ticker.C: case <-ticker.C:
ticker.Reset(a.pullInterval) ticker.Reset(a.pullInterval)
if err := a.PullTop(false); err != nil { if err := a.PullTop(false); err != nil {
log.Errorf("capi pull top: %s", err) log.Errorf("capi pull top: %s", err)
continue continue
@ -846,6 +947,7 @@ func (a *apic) Pull() error {
case <-a.pullTomb.Dying(): // if one apic routine is dying, do we kill the others? case <-a.pullTomb.Dying(): // if one apic routine is dying, do we kill the others?
a.metricsTomb.Kill(nil) a.metricsTomb.Kill(nil)
a.pushTomb.Kill(nil) a.pushTomb.Kill(nil)
return nil return nil
} }
} }
@ -858,15 +960,15 @@ func (a *apic) Shutdown() {
} }
func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[string]int) { func makeAddAndDeleteCounters() (map[string]map[string]int, map[string]map[string]int) {
add_counters := make(map[string]map[string]int) addCounters := make(map[string]map[string]int)
add_counters[types.CAPIOrigin] = make(map[string]int) addCounters[types.CAPIOrigin] = make(map[string]int)
add_counters[types.ListOrigin] = make(map[string]int) addCounters[types.ListOrigin] = make(map[string]int)
delete_counters := make(map[string]map[string]int) deleteCounters := make(map[string]map[string]int)
delete_counters[types.CAPIOrigin] = make(map[string]int) deleteCounters[types.CAPIOrigin] = make(map[string]int)
delete_counters[types.ListOrigin] = make(map[string]int) deleteCounters[types.ListOrigin] = make(map[string]int)
return add_counters, delete_counters return addCounters, deleteCounters
} }
func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) { func updateCounterForDecision(counter map[string]map[string]int, origin *string, scenario *string, totalDecisions int) {

View file

@ -2,10 +2,10 @@ package apiserver
import ( import (
"context" "context"
"slices"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"slices"
"github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/go-cs-lib/trace"
@ -66,6 +66,7 @@ func (a *apic) fetchMachineIDs() ([]string, error) {
} }
// sorted slices are required for the slices.Equal comparison // sorted slices are required for the slices.Equal comparison
slices.Sort(ret) slices.Sort(ret)
return ret, nil return ret, nil
} }
@ -91,6 +92,7 @@ func (a *apic) SendMetrics(stop chan (bool)) {
if count < len(metInts)-1 { if count < len(metInts)-1 {
count++ count++
} }
return metInts[count] return metInts[count]
} }
@ -100,8 +102,10 @@ func (a *apic) SendMetrics(stop chan (bool)) {
ids, err := a.fetchMachineIDs() ids, err := a.fetchMachineIDs()
if err != nil { if err != nil {
log.Debugf("unable to get machines (%s), will retry", err) log.Debugf("unable to get machines (%s), will retry", err)
return return
} }
machineIDs = ids machineIDs = ids
} }
@ -117,16 +121,20 @@ func (a *apic) SendMetrics(stop chan (bool)) {
case <-stop: case <-stop:
checkTicker.Stop() checkTicker.Stop()
metTicker.Stop() metTicker.Stop()
return return
case <-checkTicker.C: case <-checkTicker.C:
oldIDs := machineIDs oldIDs := machineIDs
reloadMachineIDs() reloadMachineIDs()
if !slices.Equal(oldIDs, machineIDs) { if !slices.Equal(oldIDs, machineIDs) {
log.Infof("capi metrics: machines changed, immediate send") log.Infof("capi metrics: machines changed, immediate send")
metTicker.Reset(1 * time.Millisecond) metTicker.Reset(1 * time.Millisecond)
} }
case <-metTicker.C: case <-metTicker.C:
metTicker.Stop() metTicker.Stop()
metrics, err := a.GetMetrics() metrics, err := a.GetMetrics()
if err != nil { if err != nil {
log.Errorf("unable to get metrics (%s)", err) log.Errorf("unable to get metrics (%s)", err)
@ -134,17 +142,20 @@ func (a *apic) SendMetrics(stop chan (bool)) {
// metrics are nil if they could not be retrieved // metrics are nil if they could not be retrieved
if metrics != nil { if metrics != nil {
log.Info("capi metrics: sending") log.Info("capi metrics: sending")
_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics) _, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
if err != nil { if err != nil {
log.Errorf("capi metrics: failed: %s", err) log.Errorf("capi metrics: failed: %s", err)
} }
} }
metTicker.Reset(nextMetInt()) metTicker.Reset(nextMetInt())
case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others? case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others?
checkTicker.Stop() checkTicker.Stop()
metTicker.Stop() metTicker.Stop()
a.pullTomb.Kill(nil) a.pullTomb.Kill(nil)
a.pushTomb.Kill(nil) a.pushTomb.Kill(nil)
return return
} }
} }

View file

@ -61,6 +61,7 @@ func TestAPICSendMetrics(t *testing.T) {
httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{})) httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{}))
httpmock.Activate() httpmock.Activate()
defer httpmock.Deactivate() defer httpmock.Deactivate()
for _, tc := range tests { for _, tc := range tests {

View file

@ -44,12 +44,14 @@ func getDBClient(t *testing.T) *database.Client {
DbPath: dbPath.Name(), DbPath: dbPath.Name(),
}) })
require.NoError(t, err) require.NoError(t, err)
return dbClient return dbClient
} }
func getAPIC(t *testing.T) *apic { func getAPIC(t *testing.T) *apic {
t.Helper() t.Helper()
dbClient := getDBClient(t) dbClient := getDBClient(t)
return &apic{ return &apic{
AlertsAddChan: make(chan []*models.Alert), AlertsAddChan: make(chan []*models.Alert),
//DecisionDeleteChan: make(chan []*models.Decision), //DecisionDeleteChan: make(chan []*models.Decision),
@ -74,6 +76,7 @@ func absDiff(a int, b int) (c int) {
if c = a - b; c < 0 { if c = a - b; c < 0 {
return -1 * c return -1 * c
} }
return c return c
} }
@ -94,6 +97,7 @@ func jsonMarshalX(v interface{}) []byte {
if err != nil { if err != nil {
panic(err) panic(err)
} }
return data return data
} }
@ -176,7 +180,6 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
assert.ElementsMatch(t, tc.expectedScenarios, scenarios) assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
}) })
} }
} }
@ -220,6 +223,7 @@ func TestNewAPIC(t *testing.T) {
expectedErr: "first path segment in URL cannot contain colon", expectedErr: "first path segment in URL cannot contain colon",
}, },
} }
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -274,7 +278,7 @@ func TestAPICHandleDeletedDecisions(t *testing.T) {
Scope: ptr.Of("IP"), Scope: ptr.Of("IP"),
}}, deleteCounters) }}, deleteCounters)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, nbDeleted) assert.Equal(t, 2, nbDeleted)
assert.Equal(t, 2, deleteCounters[types.CAPIOrigin]["all"]) assert.Equal(t, 2, deleteCounters[types.CAPIOrigin]["all"])
} }
@ -338,6 +342,7 @@ func TestAPICGetMetrics(t *testing.T) {
}, },
}, },
} }
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -394,6 +399,7 @@ func TestCreateAlertsForDecision(t *testing.T) {
Origin: ptr.Of(types.CAPIOrigin), Origin: ptr.Of(types.CAPIOrigin),
Scenario: ptr.Of("crowdsecurity/ssh-bf"), Scenario: ptr.Of("crowdsecurity/ssh-bf"),
} }
type args struct { type args struct {
decisions []*models.Decision decisions []*models.Decision
} }
@ -443,6 +449,7 @@ func TestCreateAlertsForDecision(t *testing.T) {
}, },
}, },
} }
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -477,6 +484,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
Scenario: ptr.Of("crowdsecurity/ssh-bf"), Scenario: ptr.Of("crowdsecurity/ssh-bf"),
Scope: ptr.Of("ip"), Scope: ptr.Of("ip"),
} }
type args struct { type args struct {
alerts []*models.Alert alerts []*models.Alert
decisions []*models.Decision decisions []*models.Decision
@ -520,6 +528,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
}, },
}, },
} }
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -546,12 +555,14 @@ func TestAPICWhitelists(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unable to parse cidr : %s", err) t.Fatalf("unable to parse cidr : %s", err)
} }
api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet) api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet)
cidrwl1 = "11.2.3.0/24" cidrwl1 = "11.2.3.0/24"
_, tnet, err = net.ParseCIDR(cidrwl1) _, tnet, err = net.ParseCIDR(cidrwl1)
if err != nil { if err != nil {
t.Fatalf("unable to parse cidr : %s", err) t.Fatalf("unable to parse cidr : %s", err)
} }
api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet) api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet)
api.dbClient.Ent.Decision.Create(). api.dbClient.Ent.Decision.Create().
SetOrigin(types.CAPIOrigin). SetOrigin(types.CAPIOrigin).
@ -564,6 +575,7 @@ func TestAPICWhitelists(t *testing.T) {
assertTotalDecisionCount(t, api.dbClient, 1) assertTotalDecisionCount(t, api.dbClient, 1)
assertTotalValidDecisionCount(t, api.dbClient, 1) assertTotalValidDecisionCount(t, api.dbClient, 1)
httpmock.Activate() httpmock.Activate()
defer httpmock.DeactivateAndReset() defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
200, jsonMarshalX( 200, jsonMarshalX(
@ -681,33 +693,39 @@ func TestAPICWhitelists(t *testing.T) {
AllX(context.Background()) AllX(context.Background())
decisionScenarioFreq := make(map[string]int) decisionScenarioFreq := make(map[string]int)
decisionIp := make(map[string]int) decisionIP := make(map[string]int)
alertScenario := make(map[string]int) alertScenario := make(map[string]int)
for _, alert := range alerts { for _, alert := range alerts {
alertScenario[alert.SourceScope]++ alertScenario[alert.SourceScope]++
} }
assert.Equal(t, 3, len(alertScenario))
assert.Len(t, alertScenario, 3)
assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope]) assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope])
assert.Equal(t, 1, alertScenario["lists:blocklist1"]) assert.Equal(t, 1, alertScenario["lists:blocklist1"])
assert.Equal(t, 1, alertScenario["lists:blocklist2"]) assert.Equal(t, 1, alertScenario["lists:blocklist2"])
for _, decisions := range validDecisions { for _, decisions := range validDecisions {
decisionScenarioFreq[decisions.Scenario]++ decisionScenarioFreq[decisions.Scenario]++
decisionIp[decisions.Value]++ decisionIP[decisions.Value]++
} }
assert.Equal(t, 1, decisionIp["2.2.3.4"], 1)
assert.Equal(t, 1, decisionIp["6.2.3.4"], 1) assert.Equal(t, 1, decisionIP["2.2.3.4"], 1)
if _, ok := decisionIp["13.2.3.4"]; ok { assert.Equal(t, 1, decisionIP["6.2.3.4"], 1)
if _, ok := decisionIP["13.2.3.4"]; ok {
t.Errorf("13.2.3.4 is whitelisted") t.Errorf("13.2.3.4 is whitelisted")
} }
if _, ok := decisionIp["13.2.3.5"]; ok {
if _, ok := decisionIP["13.2.3.5"]; ok {
t.Errorf("13.2.3.5 is whitelisted") t.Errorf("13.2.3.5 is whitelisted")
} }
if _, ok := decisionIp["9.2.3.4"]; ok {
if _, ok := decisionIP["9.2.3.4"]; ok {
t.Errorf("9.2.3.4 is whitelisted") t.Errorf("9.2.3.4 is whitelisted")
} }
assert.Equal(t, 1, decisionScenarioFreq["blocklist1"], 1) assert.Equal(t, 1, decisionScenarioFreq["blocklist1"], 1)
assert.Equal(t, 1, decisionScenarioFreq["blocklist2"], 1) assert.Equal(t, 1, decisionScenarioFreq["blocklist2"], 1)
assert.Equal(t, 2, decisionScenarioFreq["crowdsecurity/test1"], 2) assert.Equal(t, 2, decisionScenarioFreq["crowdsecurity/test1"], 2)
@ -726,6 +744,7 @@ func TestAPICPullTop(t *testing.T) {
assertTotalDecisionCount(t, api.dbClient, 1) assertTotalDecisionCount(t, api.dbClient, 1)
assertTotalValidDecisionCount(t, api.dbClient, 1) assertTotalValidDecisionCount(t, api.dbClient, 1)
httpmock.Activate() httpmock.Activate()
defer httpmock.DeactivateAndReset() defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
200, jsonMarshalX( 200, jsonMarshalX(
@ -817,7 +836,8 @@ func TestAPICPullTop(t *testing.T) {
for _, alert := range alerts { for _, alert := range alerts {
alertScenario[alert.SourceScope]++ alertScenario[alert.SourceScope]++
} }
assert.Equal(t, 3, len(alertScenario))
assert.Len(t, alertScenario, 3)
assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope]) assert.Equal(t, 1, alertScenario[types.CommunityBlocklistPullSourceScope])
assert.Equal(t, 1, alertScenario["lists:blocklist1"]) assert.Equal(t, 1, alertScenario["lists:blocklist1"])
assert.Equal(t, 1, alertScenario["lists:blocklist2"]) assert.Equal(t, 1, alertScenario["lists:blocklist2"])
@ -835,6 +855,7 @@ func TestAPICPullTop(t *testing.T) {
func TestAPICPullTopBLCacheFirstCall(t *testing.T) { func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
// no decision in db, no last modified parameter. // no decision in db, no last modified parameter.
api := getAPIC(t) api := getAPIC(t)
httpmock.Activate() httpmock.Activate()
defer httpmock.DeactivateAndReset() defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
@ -904,6 +925,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
func TestAPICPullTopBLCacheForceCall(t *testing.T) { func TestAPICPullTopBLCacheForceCall(t *testing.T) {
api := getAPIC(t) api := getAPIC(t)
httpmock.Activate() httpmock.Activate()
defer httpmock.DeactivateAndReset() defer httpmock.DeactivateAndReset()
// create a decision about to expire. It should force fetch // create a decision about to expire. It should force fetch
@ -975,6 +997,7 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
func TestAPICPullBlocklistCall(t *testing.T) { func TestAPICPullBlocklistCall(t *testing.T) {
api := getAPIC(t) api := getAPIC(t)
httpmock.Activate() httpmock.Activate()
defer httpmock.DeactivateAndReset() defer httpmock.DeactivateAndReset()

View file

@ -2,8 +2,6 @@ package apiserver
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -15,7 +13,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-co-op/gocron" "github.com/go-co-op/gocron"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
@ -23,7 +20,6 @@ import (
"github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/go-cs-lib/trace"
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers" "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers"
v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
@ -32,9 +28,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
) )
var ( const keyLength = 32
keyLength = 32
)
type APIServer struct { type APIServer struct {
URL string URL string
@ -52,57 +46,117 @@ type APIServer struct {
isEnrolled bool isEnrolled bool
} }
// RecoveryWithWriter returns a middleware for a given writer that recovers from any panics and writes a 500 if there was one. func recoverFromPanic(c *gin.Context) {
err := recover()
if err == nil {
return
}
// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
brokenPipe := false
if ne, ok := err.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
brokenPipe = true
}
}
}
// because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go
// and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them
if strErr, ok := err.(error); ok {
//stolen from http2/server.go in x/net
var (
errClientDisconnected = errors.New("client disconnected")
errClosedBody = errors.New("body closed by handler")
errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
errStreamClosed = errors.New("http2: stream closed")
)
if errors.Is(strErr, errClientDisconnected) ||
errors.Is(strErr, errClosedBody) ||
errors.Is(strErr, errHandlerComplete) ||
errors.Is(strErr, errStreamClosed) {
brokenPipe = true
}
}
if brokenPipe {
log.Warningf("client %s disconnected : %s", c.ClientIP(), err)
c.Abort()
} else {
filename := trace.WriteStackTrace(err)
log.Warningf("client %s error : %s", c.ClientIP(), err)
log.Warningf("stacktrace written to %s, please join to your issue", filename)
c.AbortWithStatus(http.StatusInternalServerError)
}
}
// CustomRecoveryWithWriter returns a middleware for a writer that recovers from any panics and writes a 500 if there was one.
func CustomRecoveryWithWriter() gin.HandlerFunc { func CustomRecoveryWithWriter() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
defer func() { defer recoverFromPanic(c)
if err := recover(); err != nil {
// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
brokenPipe = true
}
}
}
// because of https://github.com/golang/net/blob/39120d07d75e76f0079fe5d27480bcb965a21e4c/http2/server.go
// and because it seems gin doesn't handle those neither, we need to "hand define" some errors to properly catch them
if strErr, ok := err.(error); ok {
//stolen from http2/server.go in x/net
var (
errClientDisconnected = errors.New("client disconnected")
errClosedBody = errors.New("body closed by handler")
errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
errStreamClosed = errors.New("http2: stream closed")
)
if errors.Is(strErr, errClientDisconnected) ||
errors.Is(strErr, errClosedBody) ||
errors.Is(strErr, errHandlerComplete) ||
errors.Is(strErr, errStreamClosed) {
brokenPipe = true
}
}
if brokenPipe {
log.Warningf("client %s disconnected : %s", c.ClientIP(), err)
c.Abort()
} else {
filename := trace.WriteStackTrace(err)
log.Warningf("client %s error : %s", c.ClientIP(), err)
log.Warningf("stacktrace written to %s, please join to your issue", filename)
c.AbortWithStatus(http.StatusInternalServerError)
}
}
}()
c.Next() c.Next()
} }
} }
// XXX: could be a method of LocalApiServerCfg
func newGinLogger(config *csconfig.LocalApiServerCfg) (*log.Logger, string, error) {
clog := log.New()
if err := types.ConfigureLogger(clog); err != nil {
return nil, "", fmt.Errorf("while configuring gin logger: %w", err)
}
if config.LogLevel != nil {
clog.SetLevel(*config.LogLevel)
}
if config.LogMedia != "file" {
return clog, "", nil
}
// Log rotation
logFile := filepath.Join(config.LogDir, "crowdsec_api.log")
log.Debugf("starting router, logging to %s", logFile)
logger := &lumberjack.Logger{
Filename: logFile,
MaxSize: 500, //megabytes
MaxBackups: 3,
MaxAge: 28, //days
Compress: true, //disabled by default
}
if config.LogMaxSize != 0 {
logger.MaxSize = config.LogMaxSize
}
if config.LogMaxFiles != 0 {
logger.MaxBackups = config.LogMaxFiles
}
if config.LogMaxAge != 0 {
logger.MaxAge = config.LogMaxAge
}
if config.CompressLogs != nil {
logger.Compress = *config.CompressLogs
}
clog.SetOutput(logger)
return clog, logFile, nil
}
// NewServer creates a LAPI server.
// It sets up a gin router, a database client, and a controller.
func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
var flushScheduler *gocron.Scheduler var flushScheduler *gocron.Scheduler
dbClient, err := database.NewClient(config.DbConfig) dbClient, err := database.NewClient(config.DbConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to init database client: %w", err) return nil, fmt.Errorf("unable to init database client: %w", err)
@ -115,63 +169,26 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
} }
} }
logFile := ""
if config.LogMedia == "file" {
logFile = filepath.Join(config.LogDir, "crowdsec_api.log")
}
if log.GetLevel() < log.DebugLevel { if log.GetLevel() < log.DebugLevel {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
log.Debugf("starting router, logging to %s", logFile)
router := gin.New() router := gin.New()
router.ForwardedByClientIP = false
if config.TrustedProxies != nil && config.UseForwardedForHeaders { if config.TrustedProxies != nil && config.UseForwardedForHeaders {
if err := router.SetTrustedProxies(*config.TrustedProxies); err != nil { if err = router.SetTrustedProxies(*config.TrustedProxies); err != nil {
return nil, fmt.Errorf("while setting trusted_proxies: %w", err) return nil, fmt.Errorf("while setting trusted_proxies: %w", err)
} }
router.ForwardedByClientIP = true router.ForwardedByClientIP = true
} else {
router.ForwardedByClientIP = false
} }
/*The logger that will be used by handlers*/ // The logger that will be used by handlers
clog := log.New() clog, logFile, err := newGinLogger(config)
if err != nil {
if err := types.ConfigureLogger(clog); err != nil { return nil, err
return nil, fmt.Errorf("while configuring gin logger: %w", err)
}
if config.LogLevel != nil {
clog.SetLevel(*config.LogLevel)
}
/*Configure logs*/
if logFile != "" {
_maxsize := 500
if config.LogMaxSize != 0 {
_maxsize = config.LogMaxSize
}
_maxfiles := 3
if config.LogMaxFiles != 0 {
_maxfiles = config.LogMaxFiles
}
_maxage := 28
if config.LogMaxAge != 0 {
_maxage = config.LogMaxAge
}
_compress := true
if config.CompressLogs != nil {
_compress = *config.CompressLogs
}
LogOutput := &lumberjack.Logger{
Filename: logFile,
MaxSize: _maxsize, //megabytes
MaxBackups: _maxfiles,
MaxAge: _maxage, //days
Compress: _compress, //disabled by default
}
clog.SetOutput(LogOutput)
} }
gin.DefaultErrorWriter = clog.WriterLevel(log.ErrorLevel) gin.DefaultErrorWriter = clog.WriterLevel(log.ErrorLevel)
@ -206,41 +223,50 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
DisableRemoteLapiRegistration: config.DisableRemoteLapiRegistration, DisableRemoteLapiRegistration: config.DisableRemoteLapiRegistration,
} }
var apiClient *apic var (
var papiClient *Papi apiClient *apic
var isMachineEnrolled = false papiClient *Papi
isMachineEnrolled = false
)
controller.AlertsAddChan = nil
controller.DecisionDeleteChan = nil
if config.OnlineClient != nil && config.OnlineClient.Credentials != nil { if config.OnlineClient != nil && config.OnlineClient.Credentials != nil {
log.Printf("Loading CAPI manager") log.Printf("Loading CAPI manager")
apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Infof("CAPI manager configured successfully") log.Infof("CAPI manager configured successfully")
isMachineEnrolled = isEnrolled(apiClient.apiClient)
controller.AlertsAddChan = apiClient.AlertsAddChan controller.AlertsAddChan = apiClient.AlertsAddChan
if isMachineEnrolled {
if apiClient.apiClient.IsEnrolled() {
isMachineEnrolled = true
log.Infof("Machine is enrolled in the console, Loading PAPI Client") log.Infof("Machine is enrolled in the console, Loading PAPI Client")
papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel) papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel)
if err != nil { if err != nil {
return nil, err return nil, err
} }
controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel
} else { } else {
log.Errorf("Machine is not enrolled in the console, can't synchronize with the console") log.Errorf("Machine is not enrolled in the console, can't synchronize with the console")
} }
} else {
apiClient = nil
controller.AlertsAddChan = nil
controller.DecisionDeleteChan = nil
} }
if trustedIPs, err := config.GetTrustedIPs(); err == nil { trustedIPs, err := config.GetTrustedIPs()
controller.TrustedIPs = trustedIPs if err != nil {
} else {
return nil, err return nil, err
} }
controller.TrustedIPs = trustedIPs
return &APIServer{ return &APIServer{
URL: config.ListenURI, URL: config.ListenURI,
TLS: config.TLS, TLS: config.TLS,
@ -255,80 +281,20 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
consoleConfig: config.ConsoleConfig, consoleConfig: config.ConsoleConfig,
isEnrolled: isMachineEnrolled, isEnrolled: isMachineEnrolled,
}, nil }, nil
}
func isEnrolled(client *apiclient.ApiClient) bool {
apiHTTPClient := client.GetClient()
jwtTransport := apiHTTPClient.Transport.(*apiclient.JWTTransport)
tokenStr := jwtTransport.Token
token, _ := jwt.Parse(tokenStr, nil)
if token == nil {
return false
}
claims := token.Claims.(jwt.MapClaims)
_, ok := claims["organization_id"]
return ok
} }
func (s *APIServer) Router() (*gin.Engine, error) { func (s *APIServer) Router() (*gin.Engine, error) {
return s.router, nil return s.router, nil
} }
func (s *APIServer) GetTLSConfig() (*tls.Config, error) {
var caCert []byte
var err error
var caCertPool *x509.CertPool
var clientAuthType tls.ClientAuthType
if s.TLS == nil {
return &tls.Config{}, nil
}
if s.TLS.ClientVerification == "" {
//sounds like a sane default : verify client cert if given, but don't make it mandatory
clientAuthType = tls.VerifyClientCertIfGiven
} else {
clientAuthType, err = getTLSAuthType(s.TLS.ClientVerification)
if err != nil {
return nil, err
}
}
if s.TLS.CACertPath != "" {
if clientAuthType > tls.RequestClientCert {
log.Infof("(tls) Client Auth Type set to %s", clientAuthType.String())
caCert, err = os.ReadFile(s.TLS.CACertPath)
if err != nil {
return nil, fmt.Errorf("while opening cert file: %w", err)
}
caCertPool, err = x509.SystemCertPool()
if err != nil {
log.Warnf("Error loading system CA certificates: %s", err)
}
if caCertPool == nil {
caCertPool = x509.NewCertPool()
}
caCertPool.AppendCertsFromPEM(caCert)
}
}
return &tls.Config{
ServerName: s.TLS.ServerName, //should it be removed ?
ClientAuth: clientAuthType,
ClientCAs: caCertPool,
MinVersion: tls.VersionTLS12, // TLS versions below 1.2 are considered insecure - see https://www.rfc-editor.org/rfc/rfc7525.txt for details
}, nil
}
func (s *APIServer) Run(apiReady chan bool) error { func (s *APIServer) Run(apiReady chan bool) error {
defer trace.CatchPanic("lapi/runServer") defer trace.CatchPanic("lapi/runServer")
tlsCfg, err := s.GetTLSConfig()
tlsCfg, err := s.TLS.GetTLSConfig()
if err != nil { if err != nil {
return fmt.Errorf("while creating TLS config: %w", err) return fmt.Errorf("while creating TLS config: %w", err)
} }
s.httpServer = &http.Server{ s.httpServer = &http.Server{
Addr: s.URL, Addr: s.URL,
Handler: s.router, Handler: s.router,
@ -386,41 +352,74 @@ func (s *APIServer) Run(apiReady chan bool) error {
}) })
} }
s.httpServerTomb.Go(func() error { s.httpServerTomb.Go(func() error { s.listenAndServeURL(apiReady); return nil })
go func() {
apiReady <- true
log.Infof("CrowdSec Local API listening on %s", s.URL)
if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") {
if s.TLS.KeyFilePath == "" {
log.Fatalf("while serving local API: %v", errors.New("missing TLS key file"))
} else if s.TLS.CertFilePath == "" {
log.Fatalf("while serving local API: %v", errors.New("missing TLS cert file"))
}
if err := s.httpServer.ListenAndServeTLS(s.TLS.CertFilePath, s.TLS.KeyFilePath); err != nil {
log.Fatalf("while serving local API: %v", err)
}
} else {
if err := s.httpServer.ListenAndServe(); err != http.ErrServerClosed {
log.Fatalf("while serving local API: %v", err)
}
}
}()
<-s.httpServerTomb.Dying()
return nil
})
return nil return nil
} }
// listenAndServeURL starts the http server and blocks until it's closed
// it also updates the URL field with the actual address the server is listening on
// it's meant to be run in a separate goroutine
func (s *APIServer) listenAndServeURL(apiReady chan bool) {
serverError := make(chan error, 1)
go func() {
listener, err := net.Listen("tcp", s.URL)
if err != nil {
serverError <- fmt.Errorf("listening on %s: %w", s.URL, err)
return
}
s.URL = listener.Addr().String()
log.Infof("CrowdSec Local API listening on %s", s.URL)
apiReady <- true
if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") {
if s.TLS.KeyFilePath == "" {
serverError <- errors.New("missing TLS key file")
return
} else if s.TLS.CertFilePath == "" {
serverError <- errors.New("missing TLS cert file")
return
}
err = s.httpServer.ServeTLS(listener, s.TLS.CertFilePath, s.TLS.KeyFilePath)
} else {
err = s.httpServer.Serve(listener)
}
if err != nil && err != http.ErrServerClosed {
serverError <- fmt.Errorf("while serving local API: %w", err)
return
}
}()
select {
case err := <-serverError:
log.Fatalf("while starting API server: %s", err)
case <-s.httpServerTomb.Dying():
log.Infof("Shutting down API server")
// do we need a graceful shutdown here?
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.httpServer.Shutdown(ctx); err != nil {
log.Errorf("while shutting down http server: %s", err)
}
}
}
func (s *APIServer) Close() { func (s *APIServer) Close() {
if s.apic != nil { if s.apic != nil {
s.apic.Shutdown() // stop apic first since it use dbClient s.apic.Shutdown() // stop apic first since it use dbClient
} }
if s.papi != nil { if s.papi != nil {
s.papi.Shutdown() // papi also uses the dbClient s.papi.Shutdown() // papi also uses the dbClient
} }
s.dbClient.Ent.Close() s.dbClient.Ent.Close()
if s.flushScheduler != nil { if s.flushScheduler != nil {
s.flushScheduler.Stop() s.flushScheduler.Stop()
} }
@ -428,6 +427,7 @@ func (s *APIServer) Close() {
func (s *APIServer) Shutdown() error { func (s *APIServer) Shutdown() error {
s.Close() s.Close()
if s.httpServer != nil { if s.httpServer != nil {
if err := s.httpServer.Shutdown(context.TODO()); err != nil { if err := s.httpServer.Shutdown(context.TODO()); err != nil {
return err return err
@ -438,13 +438,17 @@ func (s *APIServer) Shutdown() error {
if pipe, ok := gin.DefaultErrorWriter.(*io.PipeWriter); ok { if pipe, ok := gin.DefaultErrorWriter.(*io.PipeWriter); ok {
pipe.Close() pipe.Close()
} }
if pipe, ok := gin.DefaultWriter.(*io.PipeWriter); ok { if pipe, ok := gin.DefaultWriter.(*io.PipeWriter); ok {
pipe.Close() pipe.Close()
} }
s.httpServerTomb.Kill(nil) s.httpServerTomb.Kill(nil)
if err := s.httpServerTomb.Wait(); err != nil { if err := s.httpServerTomb.Wait(); err != nil {
return fmt.Errorf("while waiting on httpServerTomb: %w", err) return fmt.Errorf("while waiting on httpServerTomb: %w", err)
} }
return nil return nil
} }
@ -453,36 +457,41 @@ func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) {
} }
func (s *APIServer) InitController() error { func (s *APIServer) InitController() error {
err := s.controller.Init() err := s.controller.Init()
if err != nil { if err != nil {
return fmt.Errorf("controller init: %w", err) return fmt.Errorf("controller init: %w", err)
} }
if s.TLS != nil {
var cacheExpiration time.Duration if s.TLS == nil {
if s.TLS.CacheExpiration != nil { return nil
cacheExpiration = *s.TLS.CacheExpiration
} else {
cacheExpiration = time.Hour
}
s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath,
cacheExpiration,
log.WithFields(log.Fields{
"component": "tls-auth",
"type": "agent",
}))
if err != nil {
return fmt.Errorf("while creating TLS auth for agents: %w", err)
}
s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath,
cacheExpiration,
log.WithFields(log.Fields{
"component": "tls-auth",
"type": "bouncer",
}))
if err != nil {
return fmt.Errorf("while creating TLS auth for bouncers: %w", err)
}
} }
return err
// TLS is configured: create the TLSAuth middleware for agents and bouncers
cacheExpiration := time.Hour
if s.TLS.CacheExpiration != nil {
cacheExpiration = *s.TLS.CacheExpiration
}
s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath,
cacheExpiration,
log.WithFields(log.Fields{
"component": "tls-auth",
"type": "agent",
}))
if err != nil {
return fmt.Errorf("while creating TLS auth for agents: %w", err)
}
s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath,
cacheExpiration,
log.WithFields(log.Fields{
"component": "tls-auth",
"type": "bouncer",
}))
if err != nil {
return fmt.Errorf("while creating TLS auth for bouncers: %w", err)
}
return nil
} }

View file

@ -11,21 +11,20 @@ import (
"testing" "testing"
"time" "time"
"github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/version" "github.com/crowdsecurity/go-cs-lib/version"
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/go-openapi/strfmt"
"github.com/pkg/errors"
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/gin-gonic/gin" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
) )
var testMachineID = "test" var testMachineID = "test"
@ -46,6 +45,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config {
} }
tempDir, _ := os.MkdirTemp("", "crowdsec_tests") tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
t.Cleanup(func() { os.RemoveAll(tempDir) }) t.Cleanup(func() { os.RemoveAll(tempDir) })
dbconfig := csconfig.DatabaseCfg{ dbconfig := csconfig.DatabaseCfg{
@ -70,6 +70,7 @@ func LoadTestConfig(t *testing.T) csconfig.Config {
if err := config.API.Server.LoadProfiles(); err != nil { if err := config.API.Server.LoadProfiles(); err != nil {
log.Fatalf("failed to load profiles: %s", err) log.Fatalf("failed to load profiles: %s", err)
} }
return config return config
} }
@ -81,6 +82,7 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
} }
tempDir, _ := os.MkdirTemp("", "crowdsec_tests") tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
t.Cleanup(func() { os.RemoveAll(tempDir) }) t.Cleanup(func() { os.RemoveAll(tempDir) })
dbconfig := csconfig.DatabaseCfg{ dbconfig := csconfig.DatabaseCfg{
@ -107,18 +109,22 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
if err := config.API.Server.LoadProfiles(); err != nil { if err := config.API.Server.LoadProfiles(); err != nil {
log.Fatalf("failed to load profiles: %s", err) log.Fatalf("failed to load profiles: %s", err)
} }
return config return config
} }
func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) { func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) {
config := LoadTestConfig(t) config := LoadTestConfig(t)
os.Remove("./ent") os.Remove("./ent")
apiServer, err := NewServer(config.API.Server) apiServer, err := NewServer(config.API.Server)
if err != nil { if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err) return nil, config, fmt.Errorf("unable to run local API: %s", err)
} }
log.Printf("Creating new API server") log.Printf("Creating new API server")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
return apiServer, config, nil return apiServer, config, nil
} }
@ -135,6 +141,7 @@ func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config, error) {
if err != nil { if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err) return nil, config, fmt.Errorf("unable to run local API: %s", err)
} }
return router, config, nil return router, config, nil
} }
@ -150,12 +157,14 @@ func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config, error)
if err != nil { if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err) return nil, config, fmt.Errorf("unable to run local API: %s", err)
} }
log.Printf("Creating new API server") log.Printf("Creating new API server")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
router, err := apiServer.Router() router, err := apiServer.Router()
if err != nil { if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err) return nil, config, fmt.Errorf("unable to run local API: %s", err)
} }
return router, config, nil return router, config, nil
} }
@ -164,9 +173,11 @@ func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error {
if err != nil { if err != nil {
return fmt.Errorf("unable to create new database client: %s", err) return fmt.Errorf("unable to create new database client: %s", err)
} }
if err := dbClient.ValidateMachine(machineID); err != nil { if err := dbClient.ValidateMachine(machineID); err != nil {
return fmt.Errorf("unable to validate machine: %s", err) return fmt.Errorf("unable to validate machine: %s", err)
} }
return nil return nil
} }
@ -179,23 +190,24 @@ func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error
if err != nil { if err != nil {
return "", fmt.Errorf("Unable to list machines: %s", err) return "", fmt.Errorf("Unable to list machines: %s", err)
} }
for _, machine := range machines { for _, machine := range machines {
if machine.MachineId == machineID { if machine.MachineId == machineID {
return machine.IpAddress, nil return machine.IpAddress, nil
} }
} }
return "", nil return "", nil
} }
func GetAlertReaderFromFile(path string) *strings.Reader { func GetAlertReaderFromFile(path string) *strings.Reader {
alertContentBytes, err := os.ReadFile(path) alertContentBytes, err := os.ReadFile(path)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
alerts := make([]*models.Alert, 0) alerts := make([]*models.Alert, 0)
if err := json.Unmarshal(alertContentBytes, &alerts); err != nil { if err = json.Unmarshal(alertContentBytes, &alerts); err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -208,12 +220,13 @@ func GetAlertReaderFromFile(path string) *strings.Reader {
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
return strings.NewReader(string(alertContent))
return strings.NewReader(string(alertContent))
} }
func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) { func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) {
var response []*models.Decision var response []*models.Decision
if resp == nil { if resp == nil {
return nil, 0, errors.New("response is nil") return nil, 0, errors.New("response is nil")
} }
@ -221,11 +234,13 @@ func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision,
if err != nil { if err != nil {
return nil, resp.Code, err return nil, resp.Code, err
} }
return response, resp.Code, nil return response, resp.Code, nil
} }
func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) { func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) {
var response map[string]string var response map[string]string
if resp == nil { if resp == nil {
return nil, 0, errors.New("response is nil") return nil, 0, errors.New("response is nil")
} }
@ -233,11 +248,13 @@ func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string,
if err != nil { if err != nil {
return nil, resp.Code, err return nil, resp.Code, err
} }
return response, resp.Code, nil return response, resp.Code, nil
} }
func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) { func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) {
var response models.DeleteDecisionResponse var response models.DeleteDecisionResponse
if resp == nil { if resp == nil {
return nil, 0, errors.New("response is nil") return nil, 0, errors.New("response is nil")
} }
@ -245,11 +262,13 @@ func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDec
if err != nil { if err != nil {
return nil, resp.Code, err return nil, resp.Code, err
} }
return &response, resp.Code, nil return &response, resp.Code, nil
} }
func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) { func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) {
response := make(map[string][]*models.Decision) response := make(map[string][]*models.Decision)
if resp == nil { if resp == nil {
return nil, 0, errors.New("response is nil") return nil, 0, errors.New("response is nil")
} }
@ -257,6 +276,7 @@ func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*mod
if err != nil { if err != nil {
return nil, resp.Code, err return nil, resp.Code, err
} }
return response, resp.Code, nil return response, resp.Code, nil
} }
@ -271,6 +291,7 @@ func CreateTestMachine(router *gin.Engine) (string, error) {
req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body))
req.Header.Set("User-Agent", UserAgent) req.Header.Set("User-Agent", UserAgent)
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
return body, nil return body, nil
} }
@ -279,10 +300,12 @@ func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) {
if err != nil { if err != nil {
log.Fatalf("unable to create new database client: %s", err) log.Fatalf("unable to create new database client: %s", err)
} }
apiKey, err := middlewares.GenerateAPIKey(keyLength) apiKey, err := middlewares.GenerateAPIKey(keyLength)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to generate api key: %s", err) return "", fmt.Errorf("unable to generate api key: %s", err)
} }
_, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to create blocker: %s", err) return "", fmt.Errorf("unable to create blocker: %s", err)
@ -322,7 +345,6 @@ func TestUnknownPath(t *testing.T) {
router.ServeHTTP(w, req) router.ServeHTTP(w, req)
assert.Equal(t, 404, w.Code) assert.Equal(t, 404, w.Code)
} }
/* /*
@ -348,6 +370,7 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
} }
tempDir, _ := os.MkdirTemp("", "crowdsec_tests") tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
t.Cleanup(func() { os.RemoveAll(tempDir) }) t.Cleanup(func() { os.RemoveAll(tempDir) })
dbconfig := csconfig.DatabaseCfg{ dbconfig := csconfig.DatabaseCfg{
@ -370,10 +393,12 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil { if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil {
t.Fatal(err) t.Fatal(err)
} }
api, err := NewServer(&cfg) api, err := NewServer(&cfg)
if err != nil { if err != nil {
t.Fatalf("failed to create api : %s", err) t.Fatalf("failed to create api : %s", err)
} }
if api == nil { if api == nil {
t.Fatalf("failed to create api #2 is nbill") t.Fatalf("failed to create api #2 is nbill")
} }
@ -397,11 +422,9 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
t.Fatalf("expected %s in %s", expectedStr, string(data)) t.Fatalf("expected %s in %s", expectedStr, string(data))
} }
} }
} }
func TestLoggingErrorToFileConfig(t *testing.T) { func TestLoggingErrorToFileConfig(t *testing.T) {
/*declare settings*/ /*declare settings*/
maxAge := "1h" maxAge := "1h"
flushConfig := csconfig.FlushDBCfg{ flushConfig := csconfig.FlushDBCfg{
@ -409,6 +432,7 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
} }
tempDir, _ := os.MkdirTemp("", "crowdsec_tests") tempDir, _ := os.MkdirTemp("", "crowdsec_tests")
t.Cleanup(func() { os.RemoveAll(tempDir) }) t.Cleanup(func() { os.RemoveAll(tempDir) })
dbconfig := csconfig.DatabaseCfg{ dbconfig := csconfig.DatabaseCfg{
@ -434,6 +458,7 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("failed to create api : %s", err) t.Fatalf("failed to create api : %s", err)
} }
if api == nil { if api == nil {
t.Fatalf("failed to create api #2 is nbill") t.Fatalf("failed to create api #2 is nbill")
} }

View file

@ -6,13 +6,14 @@ import (
"net/http" "net/http"
"github.com/alexliesenfeld/health" "github.com/alexliesenfeld/health"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1" v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers/v1"
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/csplugin"
"github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
) )
type Controller struct { type Controller struct {

View file

@ -10,15 +10,15 @@ import (
"time" "time"
jwt "github.com/appleboy/gin-jwt/v2" jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt"
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/csplugin"
"github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt"
log "github.com/sirupsen/logrus"
) )
func FormatOneAlert(alert *ent.Alert) *models.Alert { func FormatOneAlert(alert *ent.Alert) *models.Alert {

View file

@ -7,11 +7,12 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/fflag" "github.com/crowdsecurity/crowdsec/pkg/fflag"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
) )
// Format decisions for the bouncers // Format decisions for the bouncers

View file

@ -3,9 +3,10 @@ package v1
import ( import (
"net/http" "net/http"
"github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/crowdsecurity/crowdsec/pkg/database"
) )
func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) { func (c *Controller) HandleDBErrors(gctx *gin.Context, err error) {

View file

@ -3,10 +3,11 @@ package v1
import ( import (
"net/http" "net/http"
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt" "github.com/go-openapi/strfmt"
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
) )
func (c *Controller) CreateMachine(gctx *gin.Context) { func (c *Controller) CreateMachine(gctx *gin.Context) {

View file

@ -35,8 +35,11 @@ var LapiBouncerHits = prometheus.NewCounterVec(
[]string{"bouncer", "route", "method"}, []string{"bouncer", "route", "method"},
) )
/* keep track of the number of calls (per bouncer) that lead to nil/non-nil responses. /*
while it's not exact, it's a good way to know - when you have a rutpure bouncer - what is the rate of ok/ko answers you got from lapi*/ keep track of the number of calls (per bouncer) that lead to nil/non-nil responses.
while it's not exact, it's a good way to know - when you have a rutpure bouncer - what is the rate of ok/ko answers you got from lapi
*/
var LapiNilDecisions = prometheus.NewCounterVec( var LapiNilDecisions = prometheus.NewCounterVec(
prometheus.CounterOpts{ prometheus.CounterOpts{
Name: "cs_lapi_decisions_ko_total", Name: "cs_lapi_decisions_ko_total",

View file

@ -4,8 +4,9 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
) )
var ( var (

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
const ( const (
@ -91,9 +92,9 @@ func TestGetDecisionFilters(t *testing.T) {
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY) w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err := readDecisionsGetResp(w) decisions, code, err := readDecisionsGetResp(w)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 2, len(decisions)) assert.Len(t, decisions, 2)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, "91.121.79.179", *decisions[0].Value)
assert.Equal(t, int64(1), decisions[0].ID) assert.Equal(t, int64(1), decisions[0].ID)
@ -106,9 +107,9 @@ func TestGetDecisionFilters(t *testing.T) {
w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, APIKEY) w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w) decisions, code, err = readDecisionsGetResp(w)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 2, len(decisions)) assert.Len(t, decisions, 2)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, "91.121.79.179", *decisions[0].Value)
assert.Equal(t, int64(1), decisions[0].ID) assert.Equal(t, int64(1), decisions[0].ID)
@ -124,9 +125,9 @@ func TestGetDecisionFilters(t *testing.T) {
w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w) decisions, code, err = readDecisionsGetResp(w)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 1, len(decisions)) assert.Len(t, decisions, 1)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, "91.121.79.179", *decisions[0].Value)
assert.Equal(t, int64(1), decisions[0].ID) assert.Equal(t, int64(1), decisions[0].ID)
@ -139,9 +140,9 @@ func TestGetDecisionFilters(t *testing.T) {
w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w) decisions, code, err = readDecisionsGetResp(w)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 1, len(decisions)) assert.Len(t, decisions, 1)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
assert.Equal(t, "91.121.79.179", *decisions[0].Value) assert.Equal(t, "91.121.79.179", *decisions[0].Value)
assert.Equal(t, int64(1), decisions[0].ID) assert.Equal(t, int64(1), decisions[0].ID)
@ -153,12 +154,11 @@ func TestGetDecisionFilters(t *testing.T) {
w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w) decisions, code, err = readDecisionsGetResp(w)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 2, len(decisions)) assert.Len(t, decisions, 2)
assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.179") assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.179")
assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.178") assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.178")
} }
func TestGetDecision(t *testing.T) { func TestGetDecision(t *testing.T) {
@ -171,9 +171,9 @@ func TestGetDecision(t *testing.T) {
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY) w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err := readDecisionsGetResp(w) decisions, code, err := readDecisionsGetResp(w)
assert.Nil(t, err) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 3, len(decisions)) assert.Len(t, decisions, 3)
/*decisions get doesn't perform deduplication*/ /*decisions get doesn't perform deduplication*/
assert.Equal(t, "crowdsecurity/test", *decisions[0].Scenario) assert.Equal(t, "crowdsecurity/test", *decisions[0].Scenario)
assert.Equal(t, "127.0.0.1", *decisions[0].Value) assert.Equal(t, "127.0.0.1", *decisions[0].Value)
@ -190,7 +190,7 @@ func TestGetDecision(t *testing.T) {
// Get Decision with invalid filter. It should ignore this filter // Get Decision with invalid filter. It should ignore this filter
w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, APIKEY) w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, 3, len(decisions)) assert.Len(t, decisions, 3)
} }
func TestDeleteDecisionByID(t *testing.T) { func TestDeleteDecisionByID(t *testing.T) {
@ -202,47 +202,47 @@ func TestDeleteDecisionByID(t *testing.T) {
//Have one alerts //Have one alerts
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err := readDecisionsStreamResp(w) decisions, code, err := readDecisionsStreamResp(w)
assert.Equal(t, err, nil) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 0, len(decisions["deleted"])) assert.Empty(t, decisions["deleted"])
assert.Equal(t, 1, len(decisions["new"])) assert.Len(t, decisions["new"], 1)
// Delete alert with Invalid ID // Delete alert with Invalid ID
w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, PASSWORD) w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, PASSWORD)
assert.Equal(t, 400, w.Code) assert.Equal(t, 400, w.Code)
err_resp, _, err := readDecisionsErrorResp(w) errResp, _, err := readDecisionsErrorResp(w)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "decision_id must be valid integer", err_resp["message"]) assert.Equal(t, "decision_id must be valid integer", errResp["message"])
// Delete alert with ID that not exist // Delete alert with ID that not exist
w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, PASSWORD) w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, PASSWORD)
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
err_resp, _, err = readDecisionsErrorResp(w) errResp, _, err = readDecisionsErrorResp(w)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", err_resp["message"]) assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"])
//Have one alerts //Have one alerts
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 0, len(decisions["deleted"])) assert.Empty(t, decisions["deleted"])
assert.Equal(t, 1, len(decisions["new"])) assert.Len(t, decisions["new"], 1)
// Delete alert with valid ID // Delete alert with valid ID
w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD) w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
resp, _, err := readDecisionsDeleteResp(w) resp, _, err := readDecisionsDeleteResp(w)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, resp.NbDeleted, "1") assert.Equal(t, "1", resp.NbDeleted)
//Have one alert (because we delete an alert that has dup targets) //Have one alert (because we delete an alert that has dup targets)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 0, len(decisions["deleted"])) assert.Empty(t, decisions["deleted"])
assert.Equal(t, 1, len(decisions["new"])) assert.Len(t, decisions["new"], 1)
} }
func TestDeleteDecision(t *testing.T) { func TestDeleteDecision(t *testing.T) {
@ -254,16 +254,16 @@ func TestDeleteDecision(t *testing.T) {
// Delete alert with Invalid filter // Delete alert with Invalid filter
w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD)
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
err_resp, _, err := readDecisionsErrorResp(w) errResp, _, err := readDecisionsErrorResp(w)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, err_resp["message"], "'test' doesn't exist: invalid filter") assert.Equal(t, "'test' doesn't exist: invalid filter", errResp["message"])
// Delete all alert // Delete all alert
w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, PASSWORD) w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
resp, _, err := readDecisionsDeleteResp(w) resp, _, err := readDecisionsDeleteResp(w)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, resp.NbDeleted, "3") assert.Equal(t, "3", resp.NbDeleted)
} }
func TestStreamStartDecisionDedup(t *testing.T) { func TestStreamStartDecisionDedup(t *testing.T) {
@ -276,10 +276,10 @@ func TestStreamStartDecisionDedup(t *testing.T) {
// Get Stream, we only get one decision (the longest one) // Get Stream, we only get one decision (the longest one)
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err := readDecisionsStreamResp(w) decisions, code, err := readDecisionsStreamResp(w)
assert.Equal(t, nil, err) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 0, len(decisions["deleted"])) assert.Empty(t, decisions["deleted"])
assert.Equal(t, 1, len(decisions["new"])) assert.Len(t, decisions["new"], 1)
assert.Equal(t, int64(3), decisions["new"][0].ID) assert.Equal(t, int64(3), decisions["new"][0].ID)
assert.Equal(t, "test", *decisions["new"][0].Origin) assert.Equal(t, "test", *decisions["new"][0].Origin)
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
@ -291,10 +291,10 @@ func TestStreamStartDecisionDedup(t *testing.T) {
// Get Stream, we only get one decision (the longest one, id=2) // Get Stream, we only get one decision (the longest one, id=2)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, nil, err) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 0, len(decisions["deleted"])) assert.Empty(t, decisions["deleted"])
assert.Equal(t, 1, len(decisions["new"])) assert.Len(t, decisions["new"], 1)
assert.Equal(t, int64(2), decisions["new"][0].ID) assert.Equal(t, int64(2), decisions["new"][0].ID)
assert.Equal(t, "test", *decisions["new"][0].Origin) assert.Equal(t, "test", *decisions["new"][0].Origin)
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
@ -306,10 +306,10 @@ func TestStreamStartDecisionDedup(t *testing.T) {
// And get the remaining decision (1) // And get the remaining decision (1)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, nil, err) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 0, len(decisions["deleted"])) assert.Empty(t, decisions["deleted"])
assert.Equal(t, 1, len(decisions["new"])) assert.Len(t, decisions["new"], 1)
assert.Equal(t, int64(1), decisions["new"][0].ID) assert.Equal(t, int64(1), decisions["new"][0].ID)
assert.Equal(t, "test", *decisions["new"][0].Origin) assert.Equal(t, "test", *decisions["new"][0].Origin)
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
@ -321,13 +321,13 @@ func TestStreamStartDecisionDedup(t *testing.T) {
//and now we only get a deleted decision //and now we only get a deleted decision
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, nil, err) require.NoError(t, err)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, 1, len(decisions["deleted"])) assert.Len(t, decisions["deleted"], 1)
assert.Equal(t, int64(1), decisions["deleted"][0].ID) assert.Equal(t, int64(1), decisions["deleted"][0].ID)
assert.Equal(t, "test", *decisions["deleted"][0].Origin) assert.Equal(t, "test", *decisions["deleted"][0].Origin)
assert.Equal(t, "127.0.0.1", *decisions["deleted"][0].Value) assert.Equal(t, "127.0.0.1", *decisions["deleted"][0].Value)
assert.Equal(t, 0, len(decisions["new"])) assert.Empty(t, decisions["new"])
} }
type DecisionCheck struct { type DecisionCheck struct {

View file

@ -91,5 +91,4 @@ func TestLogin(t *testing.T) {
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "\"token\"") assert.Contains(t, w.Body.String(), "\"token\"")
assert.Contains(t, w.Body.String(), "\"expire\"") assert.Contains(t, w.Body.String(), "\"expire\"")
} }

View file

@ -49,7 +49,6 @@ func TestCreateMachine(t *testing.T) {
assert.Equal(t, 201, w.Code) assert.Equal(t, 201, w.Code)
assert.Equal(t, "", w.Body.String()) assert.Equal(t, "", w.Body.String())
} }
func TestCreateMachineWithForwardedFor(t *testing.T) { func TestCreateMachineWithForwardedFor(t *testing.T) {
@ -78,6 +77,7 @@ func TestCreateMachineWithForwardedFor(t *testing.T) {
if err != nil { if err != nil {
log.Fatalf("Could not get machine IP : %s", err) log.Fatalf("Could not get machine IP : %s", err)
} }
assert.Equal(t, "1.1.1.1", ip) assert.Equal(t, "1.1.1.1", ip)
} }
@ -165,5 +165,4 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
assert.Equal(t, 403, w.Code) assert.Equal(t, 403, w.Code)
assert.Equal(t, "{\"message\":\"user 'test': user already exist\"}", w.Body.String()) assert.Equal(t, "{\"message\":\"user 'test': user already exist\"}", w.Body.String())
} }

View file

@ -8,18 +8,19 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
) )
const ( const (
APIKeyHeader = "X-Api-Key" APIKeyHeader = "X-Api-Key"
bouncerContextKey = "bouncer_info" bouncerContextKey = "bouncer_info"
// max allowed by bcrypt 72 = 54 bytes in base64 // max allowed by bcrypt 72 = 54 bytes in base64
dummyAPIKeySize = 54 dummyAPIKeySize = 54
) )
type APIKey struct { type APIKey struct {

View file

@ -10,15 +10,16 @@ import (
"time" "time"
jwt "github.com/appleboy/gin-jwt/v2" jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
) )
var identityKey = "id" var identityKey = "id"
@ -46,16 +47,12 @@ func IdentityHandler(c *gin.Context) interface{} {
} }
} }
type authInput struct { type authInput struct {
machineID string machineID string
clientMachine *ent.Machine clientMachine *ent.Machine
scenariosInput []string scenariosInput []string
} }
func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
ret := authInput{} ret := authInput{}
@ -123,8 +120,6 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
return &ret, nil return &ret, nil
} }
func (j *JWT) authPlain(c *gin.Context) (*authInput, error) { func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
var loginInput models.WatcherAuthRequest var loginInput models.WatcherAuthRequest
var err error var err error
@ -169,7 +164,6 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
return &ret, nil return &ret, nil
} }
func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
var err error var err error
var auth *authInput var auth *authInput

View file

@ -1,27 +0,0 @@
package apiserver
import (
"crypto/tls"
"fmt"
log "github.com/sirupsen/logrus"
)
func getTLSAuthType(authType string) (tls.ClientAuthType, error) {
switch authType {
case "NoClientCert":
return tls.NoClientCert, nil
case "RequestClientCert":
log.Warn("RequestClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead")
return tls.RequestClientCert, nil
case "RequireAnyClientCert":
log.Warn("RequireAnyClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead")
return tls.RequireAnyClientCert, nil
case "VerifyClientCertIfGiven":
return tls.VerifyClientCertIfGiven, nil
case "RequireAndVerifyClientCert":
return tls.RequireAndVerifyClientCert, nil
default:
return 0, fmt.Errorf("unknown TLS client_verification value: %s", authType)
}
}

View file

@ -212,18 +212,6 @@ type LocalApiServerCfg struct {
CapiWhitelists *CapiWhitelist `yaml:"-"` CapiWhitelists *CapiWhitelist `yaml:"-"`
} }
type TLSCfg struct {
CertFilePath string `yaml:"cert_file"`
KeyFilePath string `yaml:"key_file"`
ClientVerification string `yaml:"client_verification,omitempty"`
ServerName string `yaml:"server_name"`
CACertPath string `yaml:"ca_cert_path"`
AllowedAgentsOU []string `yaml:"agents_allowed_ou"`
AllowedBouncersOU []string `yaml:"bouncers_allowed_ou"`
CRLPath string `yaml:"crl_path"`
CacheExpiration *time.Duration `yaml:"cache_expiration,omitempty"`
}
func (c *Config) LoadAPIServer() error { func (c *Config) LoadAPIServer() error {
if c.DisableAPI { if c.DisableAPI {
log.Warning("crowdsec local API is disabled from flag") log.Warning("crowdsec local API is disabled from flag")
@ -243,13 +231,16 @@ func (c *Config) LoadAPIServer() error {
if !*c.API.Server.Enable { if !*c.API.Server.Enable {
log.Warning("crowdsec local API is disabled because 'enable' is set to false") log.Warning("crowdsec local API is disabled because 'enable' is set to false")
c.DisableAPI = true c.DisableAPI = true
return nil
} }
if c.DisableAPI { if c.DisableAPI {
return nil return nil
} }
if c.API.Server.ListenURI == "" {
return fmt.Errorf("no listen_uri specified")
}
//inherit log level from common, then api->server //inherit log level from common, then api->server
var logLevel log.Level var logLevel log.Level
if c.API.Server.LogLevel != nil { if c.API.Server.LogLevel != nil {

View file

@ -219,7 +219,9 @@ func TestLoadAPIServer(t *testing.T) {
input: &Config{ input: &Config{
Self: []byte(configData), Self: []byte(configData),
API: &APICfg{ API: &APICfg{
Server: &LocalApiServerCfg{}, Server: &LocalApiServerCfg{
ListenURI: "http://crowdsec.api",
},
}, },
Common: &CommonCfg{ Common: &CommonCfg{
LogDir: "./testdata/", LogDir: "./testdata/",

87
pkg/csconfig/tls.go Normal file
View file

@ -0,0 +1,87 @@
package csconfig
import (
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"time"
log "github.com/sirupsen/logrus"
)
type TLSCfg struct {
CertFilePath string `yaml:"cert_file"`
KeyFilePath string `yaml:"key_file"`
ClientVerification string `yaml:"client_verification,omitempty"`
ServerName string `yaml:"server_name"`
CACertPath string `yaml:"ca_cert_path"`
AllowedAgentsOU []string `yaml:"agents_allowed_ou"`
AllowedBouncersOU []string `yaml:"bouncers_allowed_ou"`
CRLPath string `yaml:"crl_path"`
CacheExpiration *time.Duration `yaml:"cache_expiration,omitempty"`
}
func (t *TLSCfg) GetAuthType() (tls.ClientAuthType, error) {
if t.ClientVerification == "" {
// sounds like a sane default: verify client cert if given, but don't make it mandatory
return tls.VerifyClientCertIfGiven, nil
}
switch t.ClientVerification {
case "NoClientCert":
return tls.NoClientCert, nil
case "RequestClientCert":
log.Warn("RequestClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead")
return tls.RequestClientCert, nil
case "RequireAnyClientCert":
log.Warn("RequireAnyClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead")
return tls.RequireAnyClientCert, nil
case "VerifyClientCertIfGiven":
return tls.VerifyClientCertIfGiven, nil
case "RequireAndVerifyClientCert":
return tls.RequireAndVerifyClientCert, nil
default:
return 0, fmt.Errorf("unknown TLS client_verification value: %s", t.ClientVerification)
}
}
func (t *TLSCfg) GetTLSConfig() (*tls.Config, error) {
if t == nil {
return &tls.Config{}, nil
}
clientAuthType, err := t.GetAuthType()
if err != nil {
return nil, err
}
caCertPool, err := x509.SystemCertPool()
if err != nil {
log.Warnf("Error loading system CA certificates: %s", err)
}
if caCertPool == nil {
caCertPool = x509.NewCertPool()
}
// the > condition below is a weird way to say "if a client certificate is required"
// see https://pkg.go.dev/crypto/tls#ClientAuthType
if clientAuthType > tls.RequestClientCert && t.CACertPath != "" {
log.Infof("(tls) Client Auth Type set to %s", clientAuthType.String())
caCert, err := os.ReadFile(t.CACertPath)
if err != nil {
return nil, fmt.Errorf("while opening cert file: %w", err)
}
caCertPool.AppendCertsFromPEM(caCert)
}
return &tls.Config{
ServerName: t.ServerName, //should it be removed ?
ClientAuth: clientAuthType,
ClientCAs: caCertPool,
MinVersion: tls.VersionTLS12, // TLS versions below 1.2 are considered insecure - see https://www.rfc-editor.org/rfc/rfc7525.txt for details
}, nil
}

View file

@ -0,0 +1,51 @@
#!/usr/bin/env bats
# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si:
set -u
setup_file() {
load "../lib/setup_file.sh"
}
teardown_file() {
load "../lib/teardown_file.sh"
}
setup() {
load "../lib/setup.sh"
load "../lib/bats-file/load.bash"
./instance-data load
}
teardown() {
./instance-crowdsec stop
}
#----------
# Tests for LAPI configuration and startup
@test "lapi (.api.server.enable=false)" {
rune -0 config_set '.api.server.enable=false'
rune -1 "${CROWDSEC}" -no-cs
assert_stderr --partial "You must run at least the API Server or crowdsec"
}
@test "lapi (no .api.server.listen_uri)" {
rune -0 config_set 'del(.api.server.listen_uri)'
rune -1 "${CROWDSEC}" -no-cs
assert_stderr --partial "no listen_uri specified"
}
@test "lapi (bad .api.server.listen_uri)" {
rune -0 config_set '.api.server.listen_uri="127.0.0.1:-80"'
rune -1 "${CROWDSEC}" -no-cs
assert_stderr --partial "while starting API server: listening on 127.0.0.1:-80: listen tcp: address -80: invalid port"
}
@test "lapi (listen on random port)" {
config_set '.common.log_media="stdout"'
rune -0 config_set '.api.server.listen_uri="127.0.0.1:0"'
rune -0 wait-for --err "CrowdSec Local API listening on 127.0.0.1:" "${CROWDSEC}" -no-cs
}

View file

@ -15,7 +15,7 @@ setup() {
load "../lib/setup.sh" load "../lib/setup.sh"
load "../lib/bats-file/load.bash" load "../lib/bats-file/load.bash"
./instance-data load ./instance-data load
./instance-crowdsec start # don't run crowdsec here, not all tests require a running instance
} }
teardown() { teardown() {
@ -204,6 +204,7 @@ teardown() {
} }
@test "cscli lapi status" { @test "cscli lapi status" {
rune -0 ./instance-crowdsec start
rune -0 cscli lapi status rune -0 cscli lapi status
assert_stderr --partial "Loaded credentials from" assert_stderr --partial "Loaded credentials from"
@ -260,6 +261,7 @@ teardown() {
} }
@test "cscli - bad LAPI password" { @test "cscli - bad LAPI password" {
rune -0 ./instance-crowdsec start
LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path') LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path')
config_set "${LOCAL_API_CREDENTIALS}" '.password="meh"' config_set "${LOCAL_API_CREDENTIALS}" '.password="meh"'
@ -269,6 +271,7 @@ teardown() {
} }
@test "cscli metrics" { @test "cscli metrics" {
rune -0 ./instance-crowdsec start
rune -0 cscli lapi status rune -0 cscli lapi status
rune -0 cscli metrics rune -0 cscli metrics
assert_output --partial "Route" assert_output --partial "Route"
@ -297,6 +300,7 @@ teardown() {
} }
@test "cscli explain" { @test "cscli explain" {
rune -0 ./instance-crowdsec start
line="Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth): authentication failure; logname= uid=0 euid=0 tty=ssh ruser= rhost=1.2.3.4" line="Sep 19 18:33:22 scw-d95986 sshd[24347]: pam_unix(sshd:auth): authentication failure; logname= uid=0 euid=0 tty=ssh ruser= rhost=1.2.3.4"
rune -0 cscli parsers install crowdsecurity/syslog-logs rune -0 cscli parsers install crowdsecurity/syslog-logs