apiserver/apiclient: compact tests (#2694)

* apiserver/apiclient: compact tests
* update golangci-lint configuration
This commit is contained in:
mmetc 2024-01-04 17:10:36 +01:00 committed by GitHub
parent 1c03fbe99e
commit da746f77d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 563 additions and 919 deletions

View file

@ -9,6 +9,13 @@ run:
- pkg/yamlpatch/merge_test.go
linters-settings:
gci:
sections:
- standard
- default
- prefix(github.com/crowdsecurity)
- prefix(github.com/crowdsecurity/crowdsec)
gocyclo:
min-complexity: 30

View file

@ -5,13 +5,14 @@ import (
"fmt"
"net/http"
"net/url"
"reflect"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/version"
"github.com/crowdsecurity/crowdsec/pkg/models"
@ -25,12 +26,11 @@ func TestAlertsListAsMachine(t *testing.T) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
log.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -39,19 +39,16 @@ func TestAlertsListAsMachine(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
log.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
defer teardown()
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
if r.URL.RawQuery == "ip=1.2.3.4" {
testMethod(t, r, "GET")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `null`)
return
}
@ -107,36 +104,26 @@ func TestAlertsListAsMachine(t *testing.T) {
]`)
})
tcapacity := int32(5)
tduration := "59m49.264032632s"
torigin := "crowdsec"
tscenario := "crowdsecurity/ssh-bf"
tscope := "Ip"
ttype := "ban"
tvalue := "1.1.1.172"
ttimestamp := "2020-11-28 10:20:46 +0000 UTC"
teventscount := int32(6)
tleakspeed := "10s"
tmessage := "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761"
tscenariohash := "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"
tscenarioversion := "0.1"
tstartat := "2020-11-28 10:20:46.842701127 +0100 +0100"
tstopat := "2020-11-28 10:20:46.845621385 +0100 +0100"
expected := models.GetAlertsResponse{
&models.Alert{
Capacity: &tcapacity,
Capacity: ptr.Of(int32(5)),
CreatedAt: "2020-11-28T10:20:47+01:00",
Decisions: []*models.Decision{
{
Duration: &tduration,
Duration: ptr.Of("59m49.264032632s"),
ID: 1,
Origin: &torigin,
Origin: ptr.Of("crowdsec"),
Scenario: &tscenario,
Scope: &tscope,
Simulated: new(bool), //false,
Type: &ttype,
Simulated: ptr.Of(false),
Type: ptr.Of("ban"),
Value: &tvalue,
},
},
@ -167,16 +154,16 @@ func TestAlertsListAsMachine(t *testing.T) {
Timestamp: &ttimestamp,
},
},
EventsCount: &teventscount,
EventsCount: ptr.Of(int32(6)),
ID: 1,
Leakspeed: &tleakspeed,
Leakspeed: ptr.Of("10s"),
MachineID: "test",
Message: &tmessage,
Remediation: false,
Scenario: &tscenario,
ScenarioHash: &tscenariohash,
ScenarioVersion: &tscenarioversion,
Simulated: new(bool), //(false),
ScenarioHash: ptr.Of("4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"),
ScenarioVersion: ptr.Of("0.1"),
Simulated: ptr.Of(false),
Source: &models.Source{
AsName: "Cloudflare Inc",
AsNumber: "",
@ -188,8 +175,8 @@ func TestAlertsListAsMachine(t *testing.T) {
Scope: &tscope,
Value: &tvalue,
},
StartAt: &tstartat,
StopAt: &tstopat,
StartAt: ptr.Of("2020-11-28 10:20:46.842701127 +0100 +0100"),
StopAt: ptr.Of("2020-11-28 10:20:46.845621385 +0100 +0100"),
},
}
@ -198,30 +185,16 @@ func TestAlertsListAsMachine(t *testing.T) {
//log.Debugf("expected : -> %s", spew.Sdump(expected))
//first one returns data
alerts, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{})
if err != nil {
log.Errorf("test Unable to list alerts : %+v", err)
}
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, expected, *alerts)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if !reflect.DeepEqual(*alerts, expected) {
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
}
//this one doesn't
filter := AlertsListOpts{IPEquals: new(string)}
*filter.IPEquals = "1.2.3.4"
filter := AlertsListOpts{IPEquals: ptr.Of("1.2.3.4")}
alerts, resp, err = client.Alerts.List(context.Background(), filter)
if err != nil {
log.Errorf("test Unable to list alerts : %+v", err)
}
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Empty(t, *alerts)
}
@ -236,9 +209,7 @@ func TestAlertsGetAsMachine(t *testing.T) {
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
log.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -247,12 +218,10 @@ func TestAlertsGetAsMachine(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
log.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
defer teardown()
mux.HandleFunc("/alerts/2", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
w.WriteHeader(http.StatusNotFound)
@ -312,34 +281,24 @@ func TestAlertsGetAsMachine(t *testing.T) {
}`)
})
tcapacity := int32(5)
tduration := "59m49.264032632s"
torigin := "crowdsec"
tscenario := "crowdsecurity/ssh-bf"
tscope := "Ip"
ttype := "ban"
tvalue := "1.1.1.172"
ttimestamp := "2020-11-28 10:20:46 +0000 UTC"
teventscount := int32(6)
tleakspeed := "10s"
tmessage := "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761"
tscenariohash := "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"
tscenarioversion := "0.1"
tstartat := "2020-11-28 10:20:46.842701127 +0100 +0100"
tstopat := "2020-11-28 10:20:46.845621385 +0100 +0100"
expected := &models.Alert{
Capacity: &tcapacity,
Capacity: ptr.Of(int32(5)),
CreatedAt: "2020-11-28T10:20:47+01:00",
Decisions: []*models.Decision{
{
Duration: &tduration,
Duration: ptr.Of("59m49.264032632s"),
ID: 1,
Origin: &torigin,
Origin: ptr.Of("crowdsec"),
Scenario: &tscenario,
Scope: &tscope,
Simulated: new(bool), //false,
Simulated: ptr.Of(false),
Type: &ttype,
Value: &tvalue,
},
@ -371,16 +330,16 @@ func TestAlertsGetAsMachine(t *testing.T) {
Timestamp: &ttimestamp,
},
},
EventsCount: &teventscount,
EventsCount: ptr.Of(int32(6)),
ID: 1,
Leakspeed: &tleakspeed,
Leakspeed: ptr.Of("10s"),
MachineID: "test",
Message: &tmessage,
Message: ptr.Of("Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761"),
Remediation: false,
Scenario: &tscenario,
ScenarioHash: &tscenariohash,
ScenarioVersion: &tscenarioversion,
Simulated: new(bool), //(false),
ScenarioHash: ptr.Of("4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"),
ScenarioVersion: ptr.Of("0.1"),
Simulated: ptr.Of(false),
Source: &models.Source{
AsName: "Cloudflare Inc",
AsNumber: "",
@ -392,24 +351,18 @@ func TestAlertsGetAsMachine(t *testing.T) {
Scope: &tscope,
Value: &tvalue,
},
StartAt: &tstartat,
StopAt: &tstopat,
StartAt: ptr.Of("2020-11-28 10:20:46.842701127 +0100 +0100"),
StopAt: ptr.Of("2020-11-28 10:20:46.845621385 +0100 +0100"),
}
alerts, resp, err := client.Alerts.GetByID(context.Background(), 1)
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if !reflect.DeepEqual(*alerts, *expected) {
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *alerts)
//fail
_, _, err = client.Alerts.GetByID(context.Background(), 2)
assert.Contains(t, fmt.Sprintf("%s", err), "API error: object not found")
cstest.RequireErrorMessage(t, err, "API error: object not found")
}
func TestAlertsCreateAsMachine(t *testing.T) {
@ -420,17 +373,17 @@ func TestAlertsCreateAsMachine(t *testing.T) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "POST")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`["3"]`))
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
log.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -439,10 +392,7 @@ func TestAlertsCreateAsMachine(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
log.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
defer teardown()
@ -452,13 +402,8 @@ func TestAlertsCreateAsMachine(t *testing.T) {
expected := &models.AddAlertsResponse{"3"}
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if !reflect.DeepEqual(*alerts, *expected) {
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *alerts)
}
func TestAlertsDeleteAsMachine(t *testing.T) {
@ -469,18 +414,18 @@ func TestAlertsDeleteAsMachine(t *testing.T) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "DELETE")
assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"message":"0 deleted alerts"}`))
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
log.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -489,25 +434,16 @@ func TestAlertsDeleteAsMachine(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
log.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
defer teardown()
alert := AlertsDeleteOpts{IPEquals: new(string)}
*alert.IPEquals = "1.2.3.4"
alert := AlertsDeleteOpts{IPEquals: ptr.Of("1.2.3.4")}
alerts, resp, err := client.Alerts.Delete(context.Background(), alert)
require.NoError(t, err)
expected := &models.DeleteAlertsResponse{NbDeleted: ""}
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if !reflect.DeepEqual(*alerts, *expected) {
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *alerts)
}

View file

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/version"
@ -24,13 +25,11 @@ type BasicMockPayload struct {
}
func getLoginsForMockErrorCases() map[string]int {
loginsForMockErrorCases := map[string]int{
return map[string]int{
"login_400": http.StatusBadRequest,
"login_409": http.StatusConflict,
"login_500": http.StatusInternalServerError,
}
return loginsForMockErrorCases
}
func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
@ -49,7 +48,7 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
w.WriteHeader(http.StatusBadRequest)
}
responseBody := ""
var responseBody string
responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID]
if !hasFoundErrorMock {
@ -58,6 +57,7 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
} else {
responseBody = fmt.Sprintf("Error %d", responseCode)
}
log.Printf("MockServerReceived > %s // Login : [%s] => Mux response [%d]", newStr, payload.MachineID, responseCode)
w.WriteHeader(responseCode)
@ -76,14 +76,13 @@ func TestWatcherRegister(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
//body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
initBasicMuxMock(t, mux, "/watchers")
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
// Valid Registration : should retrieve the client and no err
clientconfig := Config{
@ -95,9 +94,7 @@ func TestWatcherRegister(t *testing.T) {
}
client, err := RegisterClient(&clientconfig, &http.Client{})
if client == nil || err != nil {
t.Fatalf("while registering client : %s", err)
}
require.NoError(t, err)
log.Printf("->%T", client)
@ -107,11 +104,8 @@ func TestWatcherRegister(t *testing.T) {
clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
client, err = RegisterClient(&clientconfig, &http.Client{})
if client != nil || err == nil {
t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest)
} else {
log.Printf("The RegisterClient function handled the error code %d as expected \n\r", errorCodeToTest)
}
require.Nil(t, client, "nil expected for the response code %d", errorCodeToTest)
require.Error(t, err, "error expected for the response code %d", errorCodeToTest)
}
}
@ -126,9 +120,7 @@ func TestWatcherAuth(t *testing.T) {
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok auth
clientConfig := &Config{
@ -139,34 +131,27 @@ func TestWatcherAuth(t *testing.T) {
VersionPrefix: "v1",
Scenarios: []string{"crowdsecurity/test"},
}
client, err := NewClient(clientConfig)
if err != nil {
t.Fatalf("new api client: %s", err)
}
client, err := NewClient(clientConfig)
require.NoError(t, err)
_, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
MachineID: &clientConfig.MachineID,
Password: &clientConfig.Password,
Scenarios: clientConfig.Scenarios,
})
if err != nil {
t.Fatalf("unexpect auth err 0: %s", err)
}
require.NoError(t, err)
// Testing error handling on AuthenticateWatcher (400, 409): should retrieve an error
// Not testing 500 because it loops and try to re-autehnticate. But you can test it manually by adding it in array
errorCodesToTest := [2]int{http.StatusBadRequest, http.StatusConflict}
for _, errorCodeToTest := range errorCodesToTest {
clientConfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
client, err := NewClient(clientConfig)
require.NoError(t, err)
if err != nil {
t.Fatalf("new api client: %s", err)
}
var resp *Response
_, resp, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
_, resp, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
MachineID: &clientConfig.MachineID,
Password: &clientConfig.Password,
})
@ -175,9 +160,7 @@ func TestWatcherAuth(t *testing.T) {
resp.Response.Body.Close()
bodyBytes, err := io.ReadAll(resp.Response.Body)
if err != nil {
t.Fatalf("error while reading body: %s", err.Error())
}
require.NoError(t, err)
log.Printf(string(bodyBytes))
t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest)
@ -199,10 +182,12 @@ func TestWatcherUnregister(t *testing.T) {
assert.Equal(t, int64(0), r.ContentLength)
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "POST")
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(r.Body)
newStr := buf.String()
if newStr == `{"machine_id":"test_login","password":"test_password","scenarios":["crowdsecurity/test"]}
` {
@ -217,9 +202,7 @@ func TestWatcherUnregister(t *testing.T) {
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
mycfg := &Config{
MachineID: "test_login",
@ -229,16 +212,12 @@ func TestWatcherUnregister(t *testing.T) {
VersionPrefix: "v1",
Scenarios: []string{"crowdsecurity/test"},
}
client, err := NewClient(mycfg)
if err != nil {
t.Fatalf("new api client: %s", err)
}
client, err := NewClient(mycfg)
require.NoError(t, err)
_, err = client.Auth.UnregisterWatcher(context.Background())
if err != nil {
t.Fatalf("while registering client : %s", err)
}
require.NoError(t, err)
log.Printf("->%T", client)
}
@ -255,6 +234,7 @@ func TestWatcherEnroll(t *testing.T) {
_, _ = buf.ReadFrom(r.Body)
newStr := buf.String()
log.Debugf("body -> %s", newStr)
if newStr == `{"attachment_key":"goodkey","name":"","tags":[],"overwrite":false}
` {
log.Print("good key")
@ -266,17 +246,17 @@ func TestWatcherEnroll(t *testing.T) {
fmt.Fprintf(w, `{"message":"the attachment key provided is not valid"}`)
}
})
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "POST")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`)
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
mycfg := &Config{
MachineID: "test_login",
@ -286,16 +266,12 @@ func TestWatcherEnroll(t *testing.T) {
VersionPrefix: "v1",
Scenarios: []string{"crowdsecurity/test"},
}
client, err := NewClient(mycfg)
if err != nil {
t.Fatalf("new api client: %s", err)
}
client, err := NewClient(mycfg)
require.NoError(t, err)
_, err = client.Auth.EnrollWatcher(context.Background(), "goodkey", "", []string{}, false)
if err != nil {
t.Fatalf("unexpect enroll err: %s", err)
}
require.NoError(t, err)
_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
assert.Contains(t, err.Error(), "the attachment key provided is not valid", "got %s", err.Error())

View file

@ -9,6 +9,9 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/ptr"
)
func TestApiAuth(t *testing.T) {
@ -17,6 +20,7 @@ func TestApiAuth(t *testing.T) {
mux, urlx, teardown := setup()
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
if r.Header.Get("X-Api-Key") == "ixu" {
assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
w.WriteHeader(http.StatusOK)
@ -26,11 +30,11 @@ func TestApiAuth(t *testing.T) {
w.Write([]byte(`{"message":"access forbidden"}`))
}
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
defer teardown()
@ -40,18 +44,12 @@ func TestApiAuth(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
alert := DecisionsListOpts{IPEquals: new(string)}
*alert.IPEquals = "1.2.3.4"
_, resp, err := newcli.Decisions.List(context.Background(), alert)
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
alert := DecisionsListOpts{IPEquals: ptr.Of("1.2.3.4")}
_, resp, err := newcli.Decisions.List(context.Background(), alert)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
//ko bad token
auth = &APIKeyTransport{
@ -59,25 +57,21 @@ func TestApiAuth(t *testing.T) {
}
newcli, err = NewDefaultClient(apiURL, "v1", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
_, resp, err = newcli.Decisions.List(context.Background(), alert)
log.Infof("--> %s", err)
if resp.Response.StatusCode != http.StatusForbidden {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
assert.Equal(t, http.StatusForbidden, resp.Response.StatusCode)
cstest.RequireErrorMessage(t, err, "API error: access forbidden")
assert.Contains(t, err.Error(), "API error: access forbidden")
//ko empty token
auth = &APIKeyTransport{}
newcli, err = NewDefaultClient(apiURL, "v1", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
_, _, err = newcli.Decisions.List(context.Background(), alert)
require.Error(t, err)

View file

@ -8,19 +8,19 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/version"
)
func TestNewRequestInvalid(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
//missing slash in uri
apiURL, err := url.Parse(urlx)
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -29,9 +29,8 @@ func TestNewRequestInvalid(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
/*mock login*/
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
@ -44,17 +43,16 @@ func TestNewRequestInvalid(t *testing.T) {
})
_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
assert.Contains(t, err.Error(), `building request: BaseURL must have a trailing slash, but `)
cstest.RequireErrorContains(t, err, "building request: BaseURL must have a trailing slash, but ")
}
func TestNewRequestTimeout(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
//missing slash in uri
// missing slash in uri
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -63,9 +61,8 @@ func TestNewRequestTimeout(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
/*mock login*/
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
time.Sleep(2 * time.Second)
@ -75,5 +72,5 @@ func TestNewRequestTimeout(t *testing.T) {
defer cancel()
_, _, err = client.Alerts.List(ctx, AlertsListOpts{})
assert.Contains(t, err.Error(), `performing request: context deadline exceeded`)
cstest.RequireErrorMessage(t, err, "performing request: context deadline exceeded")
}

View file

@ -11,7 +11,9 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/version"
)
@ -20,13 +22,13 @@ import (
- each test will then bind handler for the method(s) they want to try
*/
func setup() (mux *http.ServeMux, serverURL string, teardown func()) {
func setup() (*http.ServeMux, string, func()) {
return setupWithPrefix("v1")
}
func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) {
func setupWithPrefix(urlPrefix string) (*http.ServeMux, string, func()) {
// mux is the HTTP request multiplexer used with the test server.
mux = http.NewServeMux()
mux := http.NewServeMux()
baseURLPath := "/" + urlPrefix
apiHandler := http.NewServeMux()
@ -40,19 +42,16 @@ func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, te
func testMethod(t *testing.T, r *http.Request, want string) {
t.Helper()
if got := r.Method; got != want {
t.Errorf("Request method: %v, want %v", got, want)
}
assert.Equal(t, want, r.Method)
}
func TestNewClientOk(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -60,9 +59,8 @@ func TestNewClientOk(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
/*mock login*/
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
@ -75,22 +73,17 @@ func TestNewClientOk(t *testing.T) {
})
_, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{})
if err != nil {
t.Fatalf("test Unable to list alerts : %+v", err)
}
if resp.Response.StatusCode != http.StatusOK {
t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated)
}
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
}
func TestNewClientKo(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -98,9 +91,8 @@ func TestNewClientKo(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
/*mock login*/
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
@ -113,36 +105,36 @@ func TestNewClientKo(t *testing.T) {
})
_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
assert.Contains(t, err.Error(), `API error: bad login/password`)
cstest.RequireErrorContains(t, err, `API error: bad login/password`)
log.Printf("err-> %s", err)
}
func TestNewDefaultClient(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewDefaultClient(apiURL, "/v1", "", nil)
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"code": 401, "message" : "brr"}`))
})
_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
assert.Contains(t, err.Error(), `performing request: API error: brr`)
cstest.RequireErrorMessage(t, err, "performing request: API error: brr")
log.Printf("err-> %s", err)
}
func TestNewClientRegisterKO(t *testing.T) {
apiURL, err := url.Parse("http://127.0.0.1:4242/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
_, err = RegisterClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -150,17 +142,18 @@ func TestNewClientRegisterKO(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
}, &http.Client{})
if runtime.GOOS != "windows" {
assert.Contains(t, fmt.Sprintf("%s", err), "dial tcp 127.0.0.1:4242: connect: connection refused")
cstest.RequireErrorContains(t, err, "dial tcp 127.0.0.1:4242: connect: connection refused")
} else {
assert.Contains(t, fmt.Sprintf("%s", err), " No connection could be made because the target machine actively refused it.")
cstest.RequireErrorContains(t, err, " No connection could be made because the target machine actively refused it.")
}
}
func TestNewClientRegisterOK(t *testing.T) {
log.SetLevel(log.TraceLevel)
mux, urlx, teardown := setup()
mux, urlx, teardown := setup()
defer teardown()
/*mock login*/
@ -171,9 +164,8 @@ func TestNewClientRegisterOK(t *testing.T) {
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := RegisterClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -181,17 +173,15 @@ func TestNewClientRegisterOK(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
}, &http.Client{})
if err != nil {
t.Fatalf("while registering client : %s", err)
}
require.NoError(t, err)
log.Printf("->%T", client)
}
func TestNewClientBadAnswer(t *testing.T) {
log.SetLevel(log.TraceLevel)
mux, urlx, teardown := setup()
mux, urlx, teardown := setup()
defer teardown()
/*mock login*/
@ -200,10 +190,10 @@ func TestNewClientBadAnswer(t *testing.T) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`bad`))
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
_, err = RegisterClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -211,5 +201,5 @@ func TestNewClientBadAnswer(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
}, &http.Client{})
assert.Contains(t, fmt.Sprintf("%s", err), `invalid body: invalid character 'b' looking for beginning of value`)
cstest.RequireErrorContains(t, err, "invalid body: invalid character 'b' looking for beginning of value")
}

View file

@ -5,13 +5,13 @@ import (
"fmt"
"net/http"
"net/url"
"reflect"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/version"
@ -38,10 +38,9 @@ func TestDecisionsList(t *testing.T) {
//no results
}
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok answer
auth := &APIKeyTransport{
@ -49,55 +48,32 @@ func TestDecisionsList(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
tduration := "3h59m55.756182786s"
torigin := "cscli"
tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"
tscope := "Ip"
ttype := "ban"
tvalue := "1.2.3.4"
expected := &models.GetDecisionsResponse{
&models.Decision{
Duration: &tduration,
Duration: ptr.Of("3h59m55.756182786s"),
ID: 4,
Origin: &torigin,
Scenario: &tscenario,
Scope: &tscope,
Type: &ttype,
Value: &tvalue,
Origin: ptr.Of("cscli"),
Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"),
Scope: ptr.Of("Ip"),
Type: ptr.Of("ban"),
Value: ptr.Of("1.2.3.4"),
},
}
//OK decisions
decisionsFilter := DecisionsListOpts{IPEquals: new(string)}
*decisionsFilter.IPEquals = "1.2.3.4"
// OK decisions
decisionsFilter := DecisionsListOpts{IPEquals: ptr.Of("1.2.3.4")}
decisions, resp, err := newcli.Decisions.List(context.Background(), decisionsFilter)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if err != nil {
t.Fatalf("new api client: %s", err)
}
if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected)
}
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *decisions)
//Empty return
decisionsFilter = DecisionsListOpts{IPEquals: new(string)}
*decisionsFilter.IPEquals = "1.2.3.5"
decisionsFilter = DecisionsListOpts{IPEquals: ptr.Of("1.2.3.5")}
decisions, resp, err = newcli.Decisions.List(context.Background(), decisionsFilter)
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Empty(t, *decisions)
}
@ -120,6 +96,7 @@ func TestDecisionsStream(t *testing.T) {
}
}
})
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
testMethod(t, r, http.MethodDelete)
@ -129,9 +106,7 @@ func TestDecisionsStream(t *testing.T) {
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok answer
auth := &APIKeyTransport{
@ -139,63 +114,38 @@ func TestDecisionsStream(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
tduration := "3h59m55.756182786s"
torigin := "cscli"
tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"
tscope := "Ip"
ttype := "ban"
tvalue := "1.2.3.4"
expected := &models.DecisionsStreamResponse{
New: models.GetDecisionsResponse{
&models.Decision{
Duration: &tduration,
Duration: ptr.Of("3h59m55.756182786s"),
ID: 4,
Origin: &torigin,
Scenario: &tscenario,
Scope: &tscope,
Type: &ttype,
Value: &tvalue,
Origin: ptr.Of("cscli"),
Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"),
Scope: ptr.Of("Ip"),
Type: ptr.Of("ban"),
Value: ptr.Of("1.2.3.4"),
},
},
}
decisions, resp, err := newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: true})
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if err != nil {
t.Fatalf("new api client: %s", err)
}
if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *decisions)
//and second call, we get empty lists
decisions, resp, err = newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: false})
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Empty(t, decisions.New)
assert.Empty(t, decisions.Deleted)
//delete stream
resp, err = newcli.Decisions.StopStream(context.Background())
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
}
func TestDecisionsStreamV3Compatibility(t *testing.T) {
@ -219,9 +169,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok answer
auth := &APIKeyTransport{
@ -229,38 +177,30 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
tduration := "3h59m55.756182786s"
torigin := "CAPI"
tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"
tscope := "ip"
ttype := "ban"
tvalue := "1.2.3.4"
tvalue1 := "1.2.3.5"
tscenarioDeleted := "deleted"
tdurationDeleted := "1h"
expected := &models.DecisionsStreamResponse{
New: models.GetDecisionsResponse{
&models.Decision{
Duration: &tduration,
Duration: ptr.Of("3h59m55.756182786s"),
Origin: &torigin,
Scenario: &tscenario,
Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"),
Scope: &tscope,
Type: &ttype,
Value: &tvalue,
Value: ptr.Of("1.2.3.4"),
},
},
Deleted: models.GetDecisionsResponse{
&models.Decision{
Duration: &tdurationDeleted,
Duration: ptr.Of("1h"),
Origin: &torigin,
Scenario: &tscenarioDeleted,
Scenario: ptr.Of("deleted"),
Scope: &tscope,
Type: &ttype,
Value: &tvalue1,
Value: ptr.Of("1.2.3.5"),
},
},
}
@ -268,18 +208,8 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
// GetStream is supposed to consume v3 payload and return v2 response
decisions, resp, err := newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: true})
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if err != nil {
t.Fatalf("new api client: %s", err)
}
if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *decisions)
}
func TestDecisionsStreamV3(t *testing.T) {
@ -300,9 +230,7 @@ func TestDecisionsStreamV3(t *testing.T) {
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok answer
auth := &APIKeyTransport{
@ -310,30 +238,19 @@ func TestDecisionsStreamV3(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
tduration := "3h59m55.756182786s"
tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"
tscope := "ip"
tvalue := "1.2.3.4"
tvalue1 := "1.2.3.5"
tdurationBlocklist := "24h"
tnameBlocklist := "blocklist1"
tremediationBlocklist := "ban"
tscopeBlocklist := "ip"
turlBlocklist := "/v3/blocklist"
expected := &modelscapi.GetDecisionsStreamResponse{
New: modelscapi.GetDecisionsStreamResponseNew{
&modelscapi.GetDecisionsStreamResponseNewItem{
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{
Duration: &tduration,
Value: &tvalue,
Duration: ptr.Of("3h59m55.756182786s"),
Value: ptr.Of("1.2.3.4"),
},
},
Scenario: &tscenario,
Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"),
Scope: &tscope,
},
},
@ -341,18 +258,18 @@ func TestDecisionsStreamV3(t *testing.T) {
&modelscapi.GetDecisionsStreamResponseDeletedItem{
Scope: &tscope,
Decisions: []string{
tvalue1,
"1.2.3.5",
},
},
},
Links: &modelscapi.GetDecisionsStreamResponseLinks{
Blocklists: []*modelscapi.BlocklistLink{
{
Duration: &tdurationBlocklist,
Name: &tnameBlocklist,
Remediation: &tremediationBlocklist,
Scope: &tscopeBlocklist,
URL: &turlBlocklist,
Duration: ptr.Of("24h"),
Name: ptr.Of("blocklist1"),
Remediation: ptr.Of("ban"),
Scope: ptr.Of("ip"),
URL: ptr.Of("/v3/blocklist"),
},
},
},
@ -361,18 +278,8 @@ func TestDecisionsStreamV3(t *testing.T) {
// GetStream is supposed to consume v3 payload and return v2 response
decisions, resp, err := newcli.Decisions.GetStreamV3(context.Background(), DecisionsStreamOpts{Startup: true})
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if err != nil {
t.Fatalf("new api client: %s", err)
}
if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *decisions)
}
func TestDecisionsFromBlocklist(t *testing.T) {
@ -383,10 +290,13 @@ func TestDecisionsFromBlocklist(t *testing.T) {
mux.HandleFunc("/blocklist", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, http.MethodGet)
if r.Header.Get("If-Modified-Since") == "Sun, 01 Jan 2023 01:01:01 GMT" {
w.WriteHeader(http.StatusNotModified)
return
}
if r.Method == http.MethodGet {
w.WriteHeader(http.StatusOK)
w.Write([]byte("1.2.3.4\r\n1.2.3.5"))
@ -394,9 +304,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok answer
auth := &APIKeyTransport{
@ -404,12 +312,8 @@ func TestDecisionsFromBlocklist(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
tvalue1 := "1.2.3.4"
tvalue2 := "1.2.3.5"
tdurationBlocklist := "24h"
tnameBlocklist := "blocklist1"
tremediationBlocklist := "ban"
@ -419,7 +323,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
expected := []*models.Decision{
{
Duration: &tdurationBlocklist,
Value: &tvalue1,
Value: ptr.Of("1.2.3.4"),
Scenario: &tnameBlocklist,
Scope: &tscopeBlocklist,
Type: &tremediationBlocklist,
@ -427,7 +331,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
},
{
Duration: &tdurationBlocklist,
Value: &tvalue2,
Value: ptr.Of("1.2.3.5"),
Scenario: &tnameBlocklist,
Scope: &tscopeBlocklist,
Type: &tremediationBlocklist,
@ -450,13 +354,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
log.Infof("expected : %s, %s, %s, %s, %s", *expected[0].Value, *expected[0].Duration, *expected[0].Scenario, *expected[0].Scope, *expected[0].Type)
log.Infof("decisions: %s, %s, %s, %s, %s", *decisions[1].Value, *decisions[1].Duration, *decisions[1].Scenario, *decisions[1].Scope, *decisions[1].Type)
if err != nil {
t.Fatalf("new api client: %s", err)
}
if !reflect.DeepEqual(decisions, expected) {
t.Fatalf("returned %+v, want %+v", decisions, expected)
}
assert.Equal(t, expected, decisions)
// test cache control
_, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
@ -466,8 +364,10 @@ func TestDecisionsFromBlocklist(t *testing.T) {
Name: &tnameBlocklist,
Duration: &tdurationBlocklist,
}, ptr.Of("Sun, 01 Jan 2023 01:01:01 GMT"))
require.NoError(t, err)
assert.False(t, isModified)
_, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
URL: &turlBlocklist,
Scope: &tscopeBlocklist,
@ -475,6 +375,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
Name: &tnameBlocklist,
Duration: &tdurationBlocklist,
}, ptr.Of("Mon, 02 Jan 2023 01:01:01 GMT"))
require.NoError(t, err)
assert.True(t, isModified)
}
@ -485,6 +386,7 @@ func TestDeleteDecisions(t *testing.T) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "DELETE")
assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
@ -492,11 +394,12 @@ func TestDeleteDecisions(t *testing.T) {
w.Write([]byte(`{"nbDeleted":"1"}`))
//w.Write([]byte(`{"message":"0 deleted alerts"}`))
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -504,18 +407,13 @@ func TestDeleteDecisions(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
filters := DecisionsDeleteOpts{IPEquals: new(string)}
*filters.IPEquals = "1.2.3.4"
deleted, _, err := client.Decisions.Delete(context.Background(), filters)
if err != nil {
t.Fatalf("unexpected err : %s", err)
}
deleted, _, err := client.Decisions.Delete(context.Background(), filters)
require.NoError(t, err)
assert.Equal(t, "1", deleted.NbDeleted)
defer teardown()
@ -530,22 +428,23 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
ScenariosContaining string
ScenariosNotContaining string
}
tests := []struct {
name string
fields fields
want string
wantErr bool
name string
fields fields
expected string
expectedErr string
}{
{
name: "no filter",
want: baseURLString + "?",
name: "no filter",
expected: baseURLString + "?",
},
{
name: "startup=true",
fields: fields{
Startup: true,
},
want: baseURLString + "?startup=true",
expected: baseURLString + "?startup=true",
},
{
name: "set all params",
@ -555,7 +454,7 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
ScenariosContaining: "ssh",
ScenariosNotContaining: "bf",
},
want: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true",
expected: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true",
},
}
@ -568,25 +467,20 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
ScenariosContaining: tt.fields.ScenariosContaining,
ScenariosNotContaining: tt.fields.ScenariosNotContaining,
}
got, err := o.addQueryParamsToURL(baseURLString)
if (err != nil) != tt.wantErr {
t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() error = %v, wantErr %v", err, tt.wantErr)
cstest.RequireErrorContains(t, err, tt.expectedErr)
if tt.expectedErr != "" {
return
}
gotURL, err := url.Parse(got)
if err != nil {
t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() got error while parsing URL: %s", err)
}
require.NoError(t, err)
expectedURL, err := url.Parse(tt.want)
if err != nil {
t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() got error while parsing URL: %s", err)
}
expectedURL, err := url.Parse(tt.expected)
require.NoError(t, err)
if *gotURL != *expectedURL {
t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() = %v, want %v", *gotURL, *expectedURL)
}
assert.Equal(t, *expectedURL, *gotURL)
})
}
}

View file

@ -10,8 +10,8 @@ import (
"testing"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
@ -22,21 +22,14 @@ type LAPI struct {
router *gin.Engine
loginResp models.WatcherAuthResponse
bouncerKey string
t *testing.T
DBConfig *csconfig.DatabaseCfg
}
func SetupLAPITest(t *testing.T) LAPI {
t.Helper()
router, loginResp, config, err := InitMachineTest(t)
if err != nil {
t.Fatal(err)
}
router, loginResp, config := InitMachineTest(t)
APIKey, err := CreateTestBouncer(config.API.Server.DbConfig)
if err != nil {
t.Fatal(err)
}
APIKey := CreateTestBouncer(t, config.API.Server.DbConfig)
return LAPI{
router: router,
@ -46,24 +39,23 @@ func SetupLAPITest(t *testing.T) LAPI {
}
}
func (l *LAPI) InsertAlertFromFile(path string) *httptest.ResponseRecorder {
alertReader := GetAlertReaderFromFile(path)
return l.RecordResponse(http.MethodPost, "/v1/alerts", alertReader, "password")
func (l *LAPI) InsertAlertFromFile(t *testing.T, path string) *httptest.ResponseRecorder {
alertReader := GetAlertReaderFromFile(t, path)
return l.RecordResponse(t, http.MethodPost, "/v1/alerts", alertReader, "password")
}
func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder {
func (l *LAPI) RecordResponse(t *testing.T, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder {
w := httptest.NewRecorder()
req, err := http.NewRequest(verb, url, body)
if err != nil {
l.t.Fatal(err)
}
require.NoError(t, err)
if authType == "apikey" {
switch authType {
case "apikey":
req.Header.Add("X-Api-Key", l.bouncerKey)
} else if authType == "password" {
case "password":
AddAuthHeaders(req, l.loginResp)
} else {
l.t.Fatal("auth type not supported")
default:
t.Fatal("auth type not supported")
}
l.router.ServeHTTP(w, req)
@ -71,29 +63,16 @@ func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, aut
return w
}
func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config, error) {
router, config, err := NewAPITest(t)
if err != nil {
return nil, models.WatcherAuthResponse{}, config, fmt.Errorf("unable to run local API: %s", err)
}
func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) {
router, config := NewAPITest(t)
loginResp := LoginToTestAPI(t, router, config)
loginResp, err := LoginToTestAPI(router, config)
if err != nil {
return nil, models.WatcherAuthResponse{}, config, err
}
return router, loginResp, config, nil
return router, loginResp, config
}
func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherAuthResponse, error) {
body, err := CreateTestMachine(router)
if err != nil {
return models.WatcherAuthResponse{}, err
}
err = ValidateMachine("test", config.API.Server.DbConfig)
if err != nil {
log.Fatalln(err)
}
func LoginToTestAPI(t *testing.T, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse {
body := CreateTestMachine(t, router)
ValidateMachine(t, "test", config.API.Server.DbConfig)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body))
@ -101,12 +80,10 @@ func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherA
router.ServeHTTP(w, req)
loginResp := models.WatcherAuthResponse{}
err = json.NewDecoder(w.Body).Decode(&loginResp)
if err != nil {
return models.WatcherAuthResponse{}, err
}
err := json.NewDecoder(w.Body).Decode(&loginResp)
require.NoError(t, err)
return loginResp, nil
return loginResp
}
func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthResponse) {
@ -116,17 +93,17 @@ func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthRespon
func TestSimulatedAlert(t *testing.T) {
lapi := SetupLAPITest(t)
lapi.InsertAlertFromFile("./tests/alert_minibulk+simul.json")
alertContent := GetAlertReaderFromFile("./tests/alert_minibulk+simul.json")
lapi.InsertAlertFromFile(t, "./tests/alert_minibulk+simul.json")
alertContent := GetAlertReaderFromFile(t, "./tests/alert_minibulk+simul.json")
//exclude decision in simulation mode
w := lapi.RecordResponse("GET", "/v1/alerts?simulated=false", alertContent, "password")
w := lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", alertContent, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `)
assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `)
//include decision in simulation mode
w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", alertContent, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", alertContent, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `)
@ -136,35 +113,29 @@ func TestCreateAlert(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Alert with invalid format
w := lapi.RecordResponse(http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password")
w := lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password")
assert.Equal(t, 400, w.Code)
assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String())
assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
// Create Alert with invalid input
alertContent := GetAlertReaderFromFile("./tests/invalidAlert_sample.json")
alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json")
w = lapi.RecordResponse(http.MethodPost, "/v1/alerts", alertContent, "password")
w = lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", alertContent, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"validation failure list:\\n0.scenario in body is required\\n0.scenario_hash in body is required\\n0.scenario_version in body is required\\n0.simulated in body is required\\n0.source in body is required\"}", w.Body.String())
assert.Equal(t, `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, w.Body.String())
// Create Valid Alert
w = lapi.InsertAlertFromFile("./tests/alert_sample.json")
w = lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assert.Equal(t, 201, w.Code)
assert.Equal(t, "[\"1\"]", w.Body.String())
assert.Equal(t, `["1"]`, w.Body.String())
}
func TestCreateAlertChannels(t *testing.T) {
apiServer, config, err := NewAPIServer(t)
if err != nil {
log.Fatalln(err)
}
apiServer, config := NewAPIServer(t)
apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert)
apiServer.InitController()
loginResp, err := LoginToTestAPI(apiServer.router, config)
if err != nil {
log.Fatalln(err)
}
loginResp := LoginToTestAPI(t, apiServer.router, config)
lapi := LAPI{router: apiServer.router, loginResp: loginResp}
var (
@ -180,7 +151,7 @@ func TestCreateAlertChannels(t *testing.T) {
wg.Done()
}()
go lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json")
go lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json")
wg.Wait()
assert.Len(t, pd.Alert.Decisions, 1)
apiServer.Close()
@ -188,18 +159,18 @@ func TestCreateAlertChannels(t *testing.T) {
func TestAlertListFilters(t *testing.T) {
lapi := SetupLAPITest(t)
lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json")
alertContent := GetAlertReaderFromFile("./tests/alert_ssh-bf.json")
lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json")
alertContent := GetAlertReaderFromFile(t, "./tests/alert_ssh-bf.json")
//bad filter
w := lapi.RecordResponse("GET", "/v1/alerts?test=test", alertContent, "password")
w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", alertContent, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String())
assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String())
//get without filters
w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password")
assert.Equal(t, 200, w.Code)
//check alert and decision
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
@ -207,149 +178,149 @@ func TestAlertListFilters(t *testing.T) {
//test decision_type filter (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ban", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test decision_type filter (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ratata", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test scope (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?scope=Ip", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=Ip", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test scope (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?scope=rarara", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=rarara", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test scenario (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test scenario (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test ip (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test ip (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test ip (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?ip=gruueq", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String())
//test range (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test range
w = lapi.RecordResponse("GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test range (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?range=ratata", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=ratata", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String())
//test since (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?since=1h", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1h", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test since (ok but yields no results)
w = lapi.RecordResponse("GET", "/v1/alerts?since=1ns", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1ns", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test since (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?since=1zuzu", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`)
//test until (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?until=1ns", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1ns", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test until (ok but no return)
w = lapi.RecordResponse("GET", "/v1/alerts?until=1m", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1m", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test until (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?until=1zuzu", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`)
//test simulated (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test simulated (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?simulated=false", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test has active decision
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=true", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test has active decision
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=false", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test has active decision (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String())
}
@ -357,32 +328,32 @@ func TestAlertListFilters(t *testing.T) {
func TestAlertBulkInsert(t *testing.T) {
lapi := SetupLAPITest(t)
//insert a bulk of 20 alerts to trigger bulk insert
lapi.InsertAlertFromFile("./tests/alert_bulk.json")
alertContent := GetAlertReaderFromFile("./tests/alert_bulk.json")
lapi.InsertAlertFromFile(t, "./tests/alert_bulk.json")
alertContent := GetAlertReaderFromFile(t, "./tests/alert_bulk.json")
w := lapi.RecordResponse("GET", "/v1/alerts", alertContent, "password")
w := lapi.RecordResponse(t, "GET", "/v1/alerts", alertContent, "password")
assert.Equal(t, 200, w.Code)
}
func TestListAlert(t *testing.T) {
lapi := SetupLAPITest(t)
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// List Alert with invalid filter
w := lapi.RecordResponse("GET", "/v1/alerts?test=test", emptyBody, "password")
w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String())
assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String())
// List Alert
w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "crowdsecurity/test")
}
func TestCreateAlertErrors(t *testing.T) {
lapi := SetupLAPITest(t)
alertContent := GetAlertReaderFromFile("./tests/alert_sample.json")
alertContent := GetAlertReaderFromFile(t, "./tests/alert_sample.json")
//test invalid bearer
w := httptest.NewRecorder()
@ -403,7 +374,7 @@ func TestCreateAlertErrors(t *testing.T) {
func TestDeleteAlert(t *testing.T) {
lapi := SetupLAPITest(t)
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// Fail Delete Alert
w := httptest.NewRecorder()
@ -426,7 +397,7 @@ func TestDeleteAlert(t *testing.T) {
func TestDeleteAlertByID(t *testing.T) {
lapi := SetupLAPITest(t)
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// Fail Delete Alert
w := httptest.NewRecorder()
@ -454,25 +425,18 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"}
cfg.API.Server.ListenURI = "::8080"
server, err := NewServer(cfg.API.Server)
if err != nil {
log.Fatal(err)
}
require.NoError(t, err)
err = server.InitController()
if err != nil {
log.Fatal(err)
}
require.NoError(t, err)
router, err := server.Router()
if err != nil {
log.Fatal(err)
}
loginResp, err := LoginToTestAPI(router, cfg)
if err != nil {
log.Fatal(err)
}
require.NoError(t, err)
loginResp := LoginToTestAPI(t, router, cfg)
lapi := LAPI{
router: router,
loginResp: loginResp,
t: t,
}
assertAlertDeleteFailedFromIP := func(ip string) {
@ -498,17 +462,17 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
}
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assertAlertDeleteFailedFromIP("4.3.2.1")
assertAlertDeletedFromIP("1.2.3.4")
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assertAlertDeletedFromIP("1.2.4.0")
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assertAlertDeletedFromIP("1.2.4.1")
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assertAlertDeletedFromIP("1.2.4.255")
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assertAlertDeletedFromIP("127.0.0.1")
}

View file

@ -6,20 +6,14 @@ import (
"strings"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func TestAPIKey(t *testing.T) {
router, config, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, config := NewAPITest(t)
APIKey := CreateTestBouncer(t, config.API.Server.DbConfig)
APIKey, err := CreateTestBouncer(config.API.Server.DbConfig)
if err != nil {
log.Fatal(err)
}
// Login with empty token
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader(""))
@ -27,7 +21,7 @@ func TestAPIKey(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
assert.Equal(t, "{\"message\":\"access forbidden\"}", w.Body.String())
assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String())
// Login with invalid token
w = httptest.NewRecorder()
@ -37,7 +31,7 @@ func TestAPIKey(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
assert.Equal(t, "{\"message\":\"access forbidden\"}", w.Body.String())
assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String())
// Login with valid token
w = httptest.NewRecorder()

View file

@ -36,6 +36,7 @@ import (
func getDBClient(t *testing.T) *database.Client {
t.Helper()
dbPath, err := os.CreateTemp("", "*sqlite")
require.NoError(t, err)
dbClient, err := database.NewClient(&csconfig.DatabaseCfg{
@ -72,8 +73,9 @@ func getAPIC(t *testing.T) *apic {
}
}
func absDiff(a int, b int) (c int) {
if c = a - b; c < 0 {
func absDiff(a int, b int) int {
c := a - b
if c < 0 {
return -1 * c
}
@ -185,6 +187,7 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
func TestNewAPIC(t *testing.T) {
var testConfig *csconfig.OnlineApiClientCfg
setConfig := func() {
testConfig = &csconfig.OnlineApiClientCfg{
Credentials: &csconfig.ApiCredentialsCfg{
@ -199,6 +202,7 @@ func TestNewAPIC(t *testing.T) {
dbClient *database.Client
consoleConfig *csconfig.ConsoleConfig
}
tests := []struct {
name string
args args
@ -374,7 +378,6 @@ func TestAPICGetMetrics(t *testing.T) {
assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers)
assert.Equal(t, tc.expectedMetric.Machines, foundMetrics.Machines)
})
}
}
@ -403,6 +406,7 @@ func TestCreateAlertsForDecision(t *testing.T) {
type args struct {
decisions []*models.Decision
}
tests := []struct {
name string
args args
@ -489,6 +493,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
alerts []*models.Alert
decisions []*models.Decision
}
tests := []struct {
name string
args args
@ -544,26 +549,18 @@ func TestAPICWhitelists(t *testing.T) {
api := getAPIC(t)
//one whitelist on IP, one on CIDR
api.whitelists = &csconfig.CapiWhitelist{}
ipwl1 := "9.2.3.4"
ip := net.ParseIP(ipwl1)
api.whitelists.Ips = append(api.whitelists.Ips, ip)
ipwl1 = "7.2.3.4"
ip = net.ParseIP(ipwl1)
api.whitelists.Ips = append(api.whitelists.Ips, ip)
cidrwl1 := "13.2.3.0/24"
_, tnet, err := net.ParseCIDR(cidrwl1)
if err != nil {
t.Fatalf("unable to parse cidr : %s", err)
}
api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4"))
_, tnet, err := net.ParseCIDR("13.2.3.0/24")
require.NoError(t, err)
api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet)
cidrwl1 = "11.2.3.0/24"
_, tnet, err = net.ParseCIDR(cidrwl1)
if err != nil {
t.Fatalf("unable to parse cidr : %s", err)
}
_, tnet, err = net.ParseCIDR("11.2.3.0/24")
require.NoError(t, err)
api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet)
api.dbClient.Ent.Decision.Create().
SetOrigin(types.CAPIOrigin).
SetType("ban").
@ -663,12 +660,15 @@ func TestAPICWhitelists(t *testing.T) {
},
),
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", httpmock.NewStringResponder(
200, "1.2.3.6",
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist2", httpmock.NewStringResponder(
200, "1.2.3.7",
))
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
@ -801,12 +801,15 @@ func TestAPICPullTop(t *testing.T) {
},
),
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", httpmock.NewStringResponder(
200, "1.2.3.6",
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist2", httpmock.NewStringResponder(
200, "1.2.3.7",
))
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
@ -828,7 +831,8 @@ func TestAPICPullTop(t *testing.T) {
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background())
validDecisions := api.dbClient.Ent.Decision.Query().Where(
decision.UntilGT(time.Now())).
AllX(context.Background())
AllX(context.Background(),
)
decisionScenarioFreq := make(map[string]int)
alertScenario := make(map[string]int)
@ -858,6 +862,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
200, jsonMarshalX(
modelscapi.GetDecisionsStreamResponse{
@ -887,10 +892,12 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
},
),
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(200, "1.2.3.4"), nil
})
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
@ -916,6 +923,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
assert.NotEqual(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(304, ""), nil
})
err = api.PullTop(false)
require.NoError(t, err)
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
@ -928,6 +936,7 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
// create a decision about to expire. It should force fetch
alertInstance := api.dbClient.Ent.Alert.
Create().
@ -975,10 +984,12 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
},
),
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(304, ""), nil
})
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
@ -1005,6 +1016,7 @@ func TestAPICPullBlocklistCall(t *testing.T) {
assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(200, "1.2.3.4"), nil
})
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
@ -1073,6 +1085,7 @@ func TestAPICPush(t *testing.T) {
Source: &models.Source{},
}
}
return alerts
}(),
},

View file

@ -13,11 +13,12 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/version"
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
@ -63,13 +64,14 @@ func LoadTestConfig(t *testing.T) csconfig.Config {
ShareCustomScenarios: new(bool),
},
}
apiConfig := csconfig.APICfg{
Server: &apiServerConfig,
}
config.API = &apiConfig
if err := config.API.Server.LoadProfiles(); err != nil {
log.Fatalf("failed to load profiles: %s", err)
}
err := config.API.Server.LoadProfiles()
require.NoError(t, err)
return config
}
@ -106,110 +108,89 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
Server: &apiServerConfig,
}
config.API = &apiConfig
if err := config.API.Server.LoadProfiles(); err != nil {
log.Fatalf("failed to load profiles: %s", err)
}
err := config.API.Server.LoadProfiles()
require.NoError(t, err)
return config
}
func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) {
func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) {
config := LoadTestConfig(t)
os.Remove("./ent")
apiServer, err := NewServer(config.API.Server)
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
require.NoError(t, err)
log.Printf("Creating new API server")
gin.SetMode(gin.TestMode)
return apiServer, config, nil
return apiServer, config
}
func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config, error) {
apiServer, config, err := NewAPIServer(t)
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
err = apiServer.InitController()
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) {
apiServer, config := NewAPIServer(t)
err := apiServer.InitController()
require.NoError(t, err)
router, err := apiServer.Router()
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
require.NoError(t, err)
return router, config, nil
return router, config
}
func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config, error) {
func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
config := LoadTestConfigForwardedFor(t)
os.Remove("./ent")
apiServer, err := NewServer(config.API.Server)
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
require.NoError(t, err)
err = apiServer.InitController()
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
require.NoError(t, err)
log.Printf("Creating new API server")
gin.SetMode(gin.TestMode)
router, err := apiServer.Router()
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
require.NoError(t, err)
return router, config, nil
return router, config
}
func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error {
func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
dbClient, err := database.NewClient(config)
if err != nil {
return fmt.Errorf("unable to create new database client: %s", err)
}
require.NoError(t, err)
if err := dbClient.ValidateMachine(machineID); err != nil {
return fmt.Errorf("unable to validate machine: %s", err)
}
return nil
err = dbClient.ValidateMachine(machineID)
require.NoError(t, err)
}
func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error) {
func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string {
dbClient, err := database.NewClient(config)
if err != nil {
return "", fmt.Errorf("unable to create new database client: %s", err)
}
require.NoError(t, err)
machines, err := dbClient.ListMachines()
if err != nil {
return "", fmt.Errorf("Unable to list machines: %s", err)
}
require.NoError(t, err)
for _, machine := range machines {
if machine.MachineId == machineID {
return machine.IpAddress, nil
return machine.IpAddress
}
}
return "", nil
return ""
}
func GetAlertReaderFromFile(path string) *strings.Reader {
func GetAlertReaderFromFile(t *testing.T, path string) *strings.Reader {
alertContentBytes, err := os.ReadFile(path)
if err != nil {
log.Fatal(err)
}
require.NoError(t, err)
alerts := make([]*models.Alert, 0)
if err = json.Unmarshal(alertContentBytes, &alerts); err != nil {
log.Fatal(err)
}
err = json.Unmarshal(alertContentBytes, &alerts)
require.NoError(t, err)
for _, alert := range alerts {
*alert.StartAt = time.Now().UTC().Format(time.RFC3339)
@ -217,74 +198,57 @@ func GetAlertReaderFromFile(path string) *strings.Reader {
}
alertContent, err := json.Marshal(alerts)
if err != nil {
log.Fatal(err)
}
require.NoError(t, err)
return strings.NewReader(string(alertContent))
}
func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) {
func readDecisionsGetResp(t *testing.T, resp *httptest.ResponseRecorder) ([]*models.Decision, int) {
var response []*models.Decision
if resp == nil {
return nil, 0, errors.New("response is nil")
}
err := json.Unmarshal(resp.Body.Bytes(), &response)
if err != nil {
return nil, resp.Code, err
}
require.NotNil(t, resp)
return response, resp.Code, nil
err := json.Unmarshal(resp.Body.Bytes(), &response)
require.NoError(t, err)
return response, resp.Code
}
func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) {
func readDecisionsErrorResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string]string, int) {
var response map[string]string
if resp == nil {
return nil, 0, errors.New("response is nil")
}
err := json.Unmarshal(resp.Body.Bytes(), &response)
if err != nil {
return nil, resp.Code, err
}
require.NotNil(t, resp)
return response, resp.Code, nil
err := json.Unmarshal(resp.Body.Bytes(), &response)
require.NoError(t, err)
return response, resp.Code
}
func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) {
func readDecisionsDeleteResp(t *testing.T, resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int) {
var response models.DeleteDecisionResponse
if resp == nil {
return nil, 0, errors.New("response is nil")
}
require.NotNil(t, resp)
err := json.Unmarshal(resp.Body.Bytes(), &response)
if err != nil {
return nil, resp.Code, err
}
require.NoError(t, err)
return &response, resp.Code, nil
return &response, resp.Code
}
func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) {
func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int) {
response := make(map[string][]*models.Decision)
if resp == nil {
return nil, 0, errors.New("response is nil")
}
require.NotNil(t, resp)
err := json.Unmarshal(resp.Body.Bytes(), &response)
if err != nil {
return nil, resp.Code, err
}
require.NoError(t, err)
return response, resp.Code, nil
return response, resp.Code
}
func CreateTestMachine(router *gin.Engine) (string, error) {
func CreateTestMachine(t *testing.T, router *gin.Engine) string {
b, err := json.Marshal(MachineTest)
if err != nil {
return "", fmt.Errorf("unable to marshal MachineTest")
}
require.NoError(t, err)
body := string(b)
w := httptest.NewRecorder()
@ -292,26 +256,20 @@ func CreateTestMachine(router *gin.Engine) (string, error) {
req.Header.Set("User-Agent", UserAgent)
router.ServeHTTP(w, req)
return body, nil
return body
}
func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) {
func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string {
dbClient, err := database.NewClient(config)
if err != nil {
log.Fatalf("unable to create new database client: %s", err)
}
require.NoError(t, err)
apiKey, err := middlewares.GenerateAPIKey(keyLength)
if err != nil {
return "", fmt.Errorf("unable to generate api key: %s", err)
}
require.NoError(t, err)
_, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
if err != nil {
return "", fmt.Errorf("unable to create blocker: %s", err)
}
require.NoError(t, err)
return apiKey, nil
return apiKey
}
func TestWithWrongDBConfig(t *testing.T) {
@ -334,10 +292,7 @@ func TestWithWrongFlushConfig(t *testing.T) {
}
func TestUnknownPath(t *testing.T) {
router, _, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, _ := NewAPITest(t)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
@ -384,24 +339,17 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
LogDir: tempDir,
DbConfig: &dbconfig,
}
lvl := log.DebugLevel
expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
expectedLines := []string{"/test42"}
cfg.LogLevel = &lvl
cfg.LogLevel = ptr.Of(log.DebugLevel)
// Configure logging
if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil {
t.Fatal(err)
}
err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
require.NoError(t, err)
api, err := NewServer(&cfg)
if err != nil {
t.Fatalf("failed to create api : %s", err)
}
if api == nil {
t.Fatalf("failed to create api #2 is nbill")
}
require.NoError(t, err)
require.NotNil(t, api)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/test42", nil)
@ -413,14 +361,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
//check file content
data, err := os.ReadFile(expectedFile)
if err != nil {
t.Fatalf("failed to read file : %s", err)
}
require.NoError(t, err)
for _, expectedStr := range expectedLines {
if !strings.Contains(string(data), expectedStr) {
t.Fatalf("expected %s in %s", expectedStr, string(data))
}
assert.Contains(t, string(data), expectedStr)
}
}
@ -446,35 +390,29 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
LogDir: tempDir,
DbConfig: &dbconfig,
}
lvl := log.ErrorLevel
expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
cfg.LogLevel = &lvl
cfg.LogLevel = ptr.Of(log.ErrorLevel)
// Configure logging
if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil {
t.Fatal(err)
}
api, err := NewServer(&cfg)
if err != nil {
t.Fatalf("failed to create api : %s", err)
}
err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
require.NoError(t, err)
if api == nil {
t.Fatalf("failed to create api #2 is nbill")
}
api, err := NewServer(&cfg)
require.NoError(t, err)
require.NotNil(t, api)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/test42", nil)
req.Header.Set("User-Agent", UserAgent)
api.router.ServeHTTP(w, req)
assert.Equal(t, 404, w.Code)
assert.Equal(t, http.StatusNotFound, w.Code)
//wait for the request to happen
time.Sleep(500 * time.Millisecond)
//check file content
x, err := os.ReadFile(expectedFile)
if err == nil && len(x) > 0 {
t.Fatalf("file should be empty, got '%s'", x)
if err == nil {
require.Empty(t, x)
}
os.Remove("./crowdsec.log")

View file

@ -4,7 +4,6 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
@ -16,23 +15,22 @@ func TestDeleteDecisionRange(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_minibulk.json")
lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json")
// delete by ip wrong
w := lapi.RecordResponse("DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD)
w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by range
w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String())
// delete by range : ensure it was already deleted
w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
}
@ -41,23 +39,23 @@ func TestDeleteDecisionFilter(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_minibulk.json")
lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json")
// delete by ip wrong
w := lapi.RecordResponse("DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD)
w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by ip good
w = lapi.RecordResponse("DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
// delete by scope/value
w = lapi.RecordResponse("DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
}
@ -66,17 +64,17 @@ func TestDeleteDecisionFilterByScenario(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_minibulk.json")
lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json")
// delete by wrong scenario
w := lapi.RecordResponse("DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD)
w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by scenario good
w = lapi.RecordResponse("DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String())
}
@ -85,14 +83,13 @@ func TestGetDecisionFilters(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_minibulk.json")
lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json")
// Get Decision
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY)
w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err := readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code := readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 2)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
@ -104,10 +101,9 @@ func TestGetDecisionFilters(t *testing.T) {
// Get Decision : type filter
w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, APIKEY)
w = lapi.RecordResponse(t, "GET", "/v1/decisions?type=ban", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code = readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 2)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
@ -122,10 +118,9 @@ func TestGetDecisionFilters(t *testing.T) {
// Get Decision : scope/value
w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY)
w = lapi.RecordResponse(t, "GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code = readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 1)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
@ -137,10 +132,9 @@ func TestGetDecisionFilters(t *testing.T) {
// Get Decision : ip filter
w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY)
w = lapi.RecordResponse(t, "GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code = readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 1)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
@ -151,10 +145,9 @@ func TestGetDecisionFilters(t *testing.T) {
// assert.NotContains(t, w.Body.String(), `"id":2,"origin":"crowdsec","scenario":"crowdsecurity/ssh-bf","scope":"Ip","type":"ban","value":"91.121.79.178"`)
// Get decision : by range
w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY)
w = lapi.RecordResponse(t, "GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code = readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 2)
assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.179")
@ -165,13 +158,12 @@ func TestGetDecision(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// Get Decision
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY)
w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err := readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code := readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 3)
/*decisions get doesn't perform deduplication*/
@ -188,7 +180,7 @@ func TestGetDecision(t *testing.T) {
assert.Equal(t, int64(3), decisions[2].ID)
// Get Decision with invalid filter. It should ignore this filter
w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, APIKEY)
w = lapi.RecordResponse(t, "GET", "/v1/decisions?test=test", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
assert.Len(t, decisions, 3)
}
@ -197,49 +189,43 @@ func TestDeleteDecisionByID(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
//Have one alerts
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err := readDecisionsStreamResp(w)
require.NoError(t, err)
w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code := readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
// Delete alert with Invalid ID
w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/test", emptyBody, PASSWORD)
assert.Equal(t, 400, w.Code)
errResp, _, err := readDecisionsErrorResp(w)
require.NoError(t, err)
errResp, _ := readDecisionsErrorResp(t, w)
assert.Equal(t, "decision_id must be valid integer", errResp["message"])
// Delete alert with ID that not exist
w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/100", emptyBody, PASSWORD)
assert.Equal(t, 500, w.Code)
errResp, _, err = readDecisionsErrorResp(w)
require.NoError(t, err)
errResp, _ = readDecisionsErrorResp(t, w)
assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"])
//Have one alerts
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w)
require.NoError(t, err)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
// Delete alert with valid ID
w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
resp, _, err := readDecisionsDeleteResp(w)
require.NoError(t, err)
resp, _ := readDecisionsDeleteResp(t, w)
assert.Equal(t, "1", resp.NbDeleted)
//Have one alert (because we delete an alert that has dup targets)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w)
require.NoError(t, err)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
@ -249,20 +235,18 @@ func TestDeleteDecision(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// Delete alert with Invalid filter
w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD)
w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD)
assert.Equal(t, 500, w.Code)
errResp, _, err := readDecisionsErrorResp(w)
require.NoError(t, err)
errResp, _ := readDecisionsErrorResp(t, w)
assert.Equal(t, "'test' doesn't exist: invalid filter", errResp["message"])
// Delete all alert
w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
resp, _, err := readDecisionsDeleteResp(w)
require.NoError(t, err)
resp, _ := readDecisionsDeleteResp(t, w)
assert.Equal(t, "3", resp.NbDeleted)
}
@ -271,12 +255,11 @@ func TestStreamStartDecisionDedup(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// Get Stream, we only get one decision (the longest one)
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err := readDecisionsStreamResp(w)
require.NoError(t, err)
w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code := readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
@ -285,13 +268,12 @@ func TestStreamStartDecisionDedup(t *testing.T) {
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
// id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip
w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/3", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
// Get Stream, we only get one decision (the longest one, id=2)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w)
require.NoError(t, err)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
@ -300,13 +282,12 @@ func TestStreamStartDecisionDedup(t *testing.T) {
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
// We delete another decision, yet don't receive it in stream, since there's another decision on same IP
w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/2", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
// And get the remaining decision (1)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w)
require.NoError(t, err)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
@ -315,13 +296,12 @@ func TestStreamStartDecisionDedup(t *testing.T) {
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
// We delete the last decision, we receive the delete order
w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
//and now we only get a deleted decision
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w)
require.NoError(t, err)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions["deleted"], 1)
assert.Equal(t, int64(1), decisions["deleted"][0].ID)

View file

@ -10,9 +10,9 @@ import (
func TestHeartBeat(t *testing.T) {
lapi := SetupLAPITest(t)
w := lapi.RecordResponse(http.MethodGet, "/v1/heartbeat", emptyBody, "password")
w := lapi.RecordResponse(t, http.MethodGet, "/v1/heartbeat", emptyBody, "password")
assert.Equal(t, 200, w.Code)
w = lapi.RecordResponse("POST", "/v1/heartbeat", emptyBody, "password")
w = lapi.RecordResponse(t, "POST", "/v1/heartbeat", emptyBody, "password")
assert.Equal(t, 405, w.Code)
}

View file

@ -6,20 +6,13 @@ import (
"strings"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func TestLogin(t *testing.T) {
router, config, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, config := NewAPITest(t)
body, err := CreateTestMachine(router)
if err != nil {
log.Fatalln(err)
}
body := CreateTestMachine(t, router)
// Login with machine not validated yet
w := httptest.NewRecorder()
@ -28,16 +21,16 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, "{\"code\":401,\"message\":\"machine test not validated\"}", w.Body.String())
assert.Equal(t, `{"code":401,"message":"machine test not validated"}`, w.Body.String())
// Login with machine not exist
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test1\", \"password\": \"test1\"}"))
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`))
req.Header.Add("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, "{\"code\":401,\"message\":\"ent: machine not found\"}", w.Body.String())
assert.Equal(t, `{"code":401,"message":"ent: machine not found"}`, w.Body.String())
// Login with invalid body
w = httptest.NewRecorder()
@ -46,31 +39,28 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, "{\"code\":401,\"message\":\"missing: invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String())
assert.Equal(t, `{"code":401,"message":"missing: invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
// Login with invalid format
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test1\"}"))
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`))
req.Header.Add("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, "{\"code\":401,\"message\":\"validation failure list:\\npassword in body is required\"}", w.Body.String())
assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String())
//Validate machine
err = ValidateMachine("test", config.API.Server.DbConfig)
if err != nil {
log.Fatalln(err)
}
ValidateMachine(t, "test", config.API.Server.DbConfig)
// Login with invalid password
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test\", \"password\": \"test1\"}"))
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`))
req.Header.Add("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, "{\"code\":401,\"message\":\"incorrect Username or Password\"}", w.Body.String())
assert.Equal(t, `{"code":401,"message":"incorrect Username or Password"}`, w.Body.String())
// Login with valid machine
w = httptest.NewRecorder()
@ -79,16 +69,16 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "\"token\"")
assert.Contains(t, w.Body.String(), "\"expire\"")
assert.Contains(t, w.Body.String(), `"token"`)
assert.Contains(t, w.Body.String(), `"expire"`)
// Login with valid machine + scenarios
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test\", \"password\": \"test\", \"scenarios\": [\"crowdsecurity/test\", \"crowdsecurity/test2\"]}"))
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`))
req.Header.Add("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "\"token\"")
assert.Contains(t, w.Body.String(), "\"expire\"")
assert.Contains(t, w.Body.String(), `"token"`)
assert.Contains(t, w.Body.String(), `"expire"`)
}

View file

@ -7,15 +7,12 @@ import (
"strings"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateMachine(t *testing.T) {
router, _, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, _ := NewAPITest(t)
// Create machine with invalid format
w := httptest.NewRecorder()
@ -24,22 +21,21 @@ func TestCreateMachine(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 400, w.Code)
assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String())
assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
// Create machine with invalid input
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader("{\"test\": \"test\"}"))
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`))
req.Header.Add("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"validation failure list:\\nmachine_id in body is required\\npassword in body is required\"}", w.Body.String())
assert.Equal(t, `{"message":"validation failure list:\nmachine_id in body is required\npassword in body is required"}`, w.Body.String())
// Create machine
b, err := json.Marshal(MachineTest)
if err != nil {
log.Fatal("unable to marshal MachineTest")
}
require.NoError(t, err)
body := string(b)
w = httptest.NewRecorder()
@ -52,16 +48,12 @@ func TestCreateMachine(t *testing.T) {
}
func TestCreateMachineWithForwardedFor(t *testing.T) {
router, config, err := NewAPITestForwardedFor(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, config := NewAPITestForwardedFor(t)
router.TrustedPlatform = "X-Real-IP"
// Create machine
b, err := json.Marshal(MachineTest)
if err != nil {
log.Fatal("unable to marshal MachineTest")
}
require.NoError(t, err)
body := string(b)
w := httptest.NewRecorder()
@ -73,25 +65,18 @@ func TestCreateMachineWithForwardedFor(t *testing.T) {
assert.Equal(t, 201, w.Code)
assert.Equal(t, "", w.Body.String())
ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig)
if err != nil {
log.Fatalf("Could not get machine IP : %s", err)
}
ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig)
assert.Equal(t, "1.1.1.1", ip)
}
func TestCreateMachineWithForwardedForNoConfig(t *testing.T) {
router, config, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, config := NewAPITest(t)
// Create machine
b, err := json.Marshal(MachineTest)
if err != nil {
log.Fatal("unable to marshal MachineTest")
}
require.NoError(t, err)
body := string(b)
w := httptest.NewRecorder()
@ -103,26 +88,20 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) {
assert.Equal(t, 201, w.Code)
assert.Equal(t, "", w.Body.String())
ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig)
if err != nil {
log.Fatalf("Could not get machine IP : %s", err)
}
ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig)
//For some reason, the IP is empty when running tests
//if no forwarded-for headers are present
assert.Equal(t, "", ip)
}
func TestCreateMachineWithoutForwardedFor(t *testing.T) {
router, config, err := NewAPITestForwardedFor(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, config := NewAPITestForwardedFor(t)
// Create machine
b, err := json.Marshal(MachineTest)
if err != nil {
log.Fatal("unable to marshal MachineTest")
}
require.NoError(t, err)
body := string(b)
w := httptest.NewRecorder()
@ -133,25 +112,17 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) {
assert.Equal(t, 201, w.Code)
assert.Equal(t, "", w.Body.String())
ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig)
if err != nil {
log.Fatalf("Could not get machine IP : %s", err)
}
ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig)
//For some reason, the IP is empty when running tests
//if no forwarded-for headers are present
assert.Equal(t, "", ip)
}
func TestCreateMachineAlreadyExist(t *testing.T) {
router, _, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, _ := NewAPITest(t)
body, err := CreateTestMachine(router)
if err != nil {
log.Fatalln(err)
}
body := CreateTestMachine(t, router)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body))
@ -164,5 +135,5 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
assert.Equal(t, "{\"message\":\"user 'test': user already exist\"}", w.Body.String())
assert.Equal(t, `{"message":"user 'test': user already exist"}`, w.Body.String())
}