From 12a4a5fb1471d81d2a5ef4363ab0bfcfcb35291b Mon Sep 17 00:00:00 2001 From: JDEV Date: Fri, 17 Feb 2023 14:57:46 +0100 Subject: [PATCH] CAPI error code handling tests (#2027) * Registration mocked error cases * Authentication mock error cases * mini facto * check that getMEtric still has bouncers/machines keys in output even with empty collections * fixed defer body close(), no need to defer and fprint arg * fix fatal call --------- Co-authored-by: jdv --- pkg/apiclient/auth_service_test.go | 207 ++++++++++++++++++----------- pkg/apiserver/apic_test.go | 10 ++ 2 files changed, 140 insertions(+), 77 deletions(-) diff --git a/pkg/apiclient/auth_service_test.go b/pkg/apiclient/auth_service_test.go index c844afe8e..1e3e83e04 100644 --- a/pkg/apiclient/auth_service_test.go +++ b/pkg/apiclient/auth_service_test.go @@ -3,7 +3,9 @@ package apiclient import ( "bytes" "context" + "encoding/json" "fmt" + "io" "net/http" "net/url" "testing" @@ -14,6 +16,99 @@ import ( "github.com/stretchr/testify/assert" ) +type BasicMockPayload struct { + MachineID string `json:"machine_id"` + Password string `json:"password"` +} + +func getLoginsForMockErrorCases() map[string]int { + loginsForMockErrorCases := map[string]int{ + "login_400": http.StatusBadRequest, + "login_409": http.StatusConflict, + "login_500": http.StatusInternalServerError, + } + + return loginsForMockErrorCases +} + +func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { + loginsForMockErrorCases := getLoginsForMockErrorCases() + mux.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, "POST") + buf := new(bytes.Buffer) + _, _ = buf.ReadFrom(r.Body) + newStr := buf.String() + + var payload BasicMockPayload + err := json.Unmarshal([]byte(newStr), &payload) + if err != nil || payload.MachineID == "" || payload.Password == "" { + log.Printf("Bad payload") + w.WriteHeader(http.StatusBadRequest) + } + + responseBody := "" + responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID] + + if !hasFoundErrorMock { + responseCode = http.StatusOK + responseBody = `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}` + } else { + responseBody = fmt.Sprintf("Error %d", responseCode) + } + log.Printf("MockServerReceived > %s // Login : [%s] => Mux response [%d]", newStr, payload.MachineID, responseCode) + + w.WriteHeader(responseCode) + fmt.Fprintf(w, `%s`, responseBody) + }) +} + +/** + * Test the RegisterClient function + * Making sure it handles the different response code potentially coming from CAPI properly + * 200 => OK + * 400, 409, 500 => Error + */ +func TestWatcherRegister(t *testing.T) { + + log.SetLevel(log.DebugLevel) + + mux, urlx, teardown := setup() + defer teardown() + //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} + initBasicMuxMock(t, mux, "/watchers") + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") + if err != nil { + t.Fatalf("parsing api url: %s", apiURL) + } + + // Valid Registration : should retrieve the client and no err + clientconfig := Config{ + MachineID: "test_login", + Password: "test_password", + UserAgent: fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), + URL: apiURL, + VersionPrefix: "v1", + } + client, err := RegisterClient(&clientconfig, &http.Client{}) + if client == nil || err != nil { + t.Fatalf("while registering client : %s", err) + } + log.Printf("->%T", client) + + // Testing error handling on Registration (400, 409, 500): should retrieve an error + errorCodesToTest := [3]int{http.StatusBadRequest, http.StatusConflict, http.StatusInternalServerError} + for _, errorCodeToTest := range errorCodesToTest { + clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) + client, err = RegisterClient(&clientconfig, &http.Client{}) + if client != nil || err == nil { + t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest) + } else { + log.Printf("The RegisterClient function handled the error code %d as expected \n\r", errorCodeToTest) + } + } +} + func TestWatcherAuth(t *testing.T) { log.SetLevel(log.DebugLevel) @@ -22,23 +117,7 @@ func TestWatcherAuth(t *testing.T) { defer teardown() //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} - mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "POST") - buf := new(bytes.Buffer) - _, _ = buf.ReadFrom(r.Body) - newStr := buf.String() - log.Printf("--> %s", newStr) - if newStr == `{"machine_id":"test_login","password":"test_password","scenarios":["crowdsecurity/test"]} -` { - log.Printf("ok cool") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`) - } else { - w.WriteHeader(http.StatusForbidden) - log.Printf("badbad") - fmt.Fprintf(w, `{"message":"access forbidden"}`) - } - }) + initBasicMuxMock(t, mux, "/watchers/login") log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") if err != nil { @@ -46,7 +125,7 @@ func TestWatcherAuth(t *testing.T) { } //ok auth - mycfg := &Config{ + clientConfig := &Config{ MachineID: "test_login", Password: "test_password", UserAgent: fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), @@ -54,77 +133,51 @@ func TestWatcherAuth(t *testing.T) { VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, } - client, err := NewClient(mycfg) + client, err := NewClient(clientConfig) if err != nil { t.Fatalf("new api client: %s", err) } _, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ - MachineID: &mycfg.MachineID, - Password: &mycfg.Password, - Scenarios: mycfg.Scenarios, + MachineID: &clientConfig.MachineID, + Password: &clientConfig.Password, + Scenarios: clientConfig.Scenarios, }) if err != nil { t.Fatalf("unexpect auth err 0: %s", err) } - //bad auth - mycfg = &Config{ - MachineID: "BADtest_login", - Password: "BADtest_password", - UserAgent: fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), - URL: apiURL, - VersionPrefix: "v1", - Scenarios: []string{"crowdsecurity/test"}, + // Testing error handling on AuthenticateWatcher (400, 409): should retrieve an error + // Not testing 500 because it loops and try to re-autehnticate. But you can test it manually by adding it in array + errorCodesToTest := [2]int{http.StatusBadRequest, http.StatusConflict} + for _, errorCodeToTest := range errorCodesToTest { + clientConfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) + client, err := NewClient(clientConfig) + + if err != nil { + t.Fatalf("new api client: %s", err) + } + + var resp *Response + _, resp, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + MachineID: &clientConfig.MachineID, + Password: &clientConfig.Password, + }) + + if err == nil { + resp.Response.Body.Close() + bodyBytes, err := io.ReadAll(resp.Response.Body) + if err != nil { + t.Fatalf("error while reading body: %s", err.Error()) + } + + log.Printf(string(bodyBytes)) + t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest) + } else { + log.Printf("The AuthenticateWatcher function handled the error code %d as expected \n\r", errorCodeToTest) + } } - client, err = NewClient(mycfg) - - if err != nil { - t.Fatalf("new api client: %s", err) - } - - _, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ - MachineID: &mycfg.MachineID, - Password: &mycfg.Password, - }) - assert.Contains(t, err.Error(), "API error: access forbidden") - -} - -func TestWatcherRegister(t *testing.T) { - - log.SetLevel(log.DebugLevel) - - mux, urlx, teardown := setup() - defer teardown() - //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} - - mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) { - testMethod(t, r, "POST") - buf := new(bytes.Buffer) - _, _ = buf.ReadFrom(r.Body) - newStr := buf.String() - assert.Equal(t, newStr, `{"machine_id":"test_login","password":"test_password"} -`) - w.WriteHeader(http.StatusOK) - }) - log.Printf("URL is %s", urlx) - apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } - client, err := RegisterClient(&Config{ - MachineID: "test_login", - Password: "test_password", - UserAgent: fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), - URL: apiURL, - VersionPrefix: "v1", - }, &http.Client{}) - if err != nil { - t.Fatalf("while registering client : %s", err) - } - log.Printf("->%T", client) } func TestWatcherUnregister(t *testing.T) { diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 9168d59ef..fec7d8182 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -286,6 +286,16 @@ func TestAPICGetMetrics(t *testing.T) { bouncers []string expectedMetric *models.Metrics }{ + { + name: "no bouncers nor machines should still have bouncers/machines keys in output", + machineIDs: []string{}, + bouncers: []string{}, + expectedMetric: &models.Metrics{ + ApilVersion: types.StrPtr(cwversion.VersionStr()), + Bouncers: []*models.MetricsBouncerInfo{}, + Machines: []*models.MetricsAgentInfo{}, + }, + }, { name: "simple", machineIDs: []string{"a", "b", "c"},