From 437a97510ae8b60e1f7cb4ce6cc23ad82ea01f98 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Wed, 10 Jan 2024 12:00:22 +0100 Subject: [PATCH] apiclient: handle 0-byte error response (#2716) * apiclient: correctly handle 0-byte response * lint --- pkg/apiclient/auth.go | 20 +++++++++++++------- pkg/apiclient/client.go | 6 +++--- pkg/apiclient/decisions_service.go | 4 +++- pkg/apiclient/heartbeat.go | 2 ++ pkg/apiserver/apiserver.go | 4 ++++ pkg/apiserver/controllers/controller.go | 3 ++- pkg/apiserver/controllers/v1/alerts.go | 21 +++++++++++++++------ 7 files changed, 42 insertions(+), 18 deletions(-) diff --git a/pkg/apiclient/auth.go b/pkg/apiclient/auth.go index 86cdc7736..163e96718 100644 --- a/pkg/apiclient/auth.go +++ b/pkg/apiclient/auth.go @@ -3,6 +3,7 @@ package apiclient import ( "bytes" "encoding/json" + "errors" "fmt" "io" "math/rand" @@ -13,7 +14,6 @@ import ( "time" "github.com/go-openapi/strfmt" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/fflag" @@ -52,10 +52,12 @@ func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) { dump, _ := httputil.DumpRequest(req, true) log.Tracef("auth-api request: %s", string(dump)) } + // Make the HTTP request. resp, err := t.transport().RoundTrip(req) if err != nil { log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err) + return resp, err } @@ -115,10 +117,12 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) for i := 0; i < maxAttempts; i++ { if i > 0 { if r.withBackOff { + //nolint:gosec backoff += 10 + rand.Intn(20) } log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts) + select { case <-req.Context().Done(): return resp, req.Context().Err() @@ -134,8 +138,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) resp, err = r.next.RoundTrip(clonedReq) if err != nil { - left := maxAttempts - i - 1 - if left > 0 { + if left := maxAttempts - i - 1; left > 0 { log.Errorf("error while performing request: %s; %d retries left", err, left) } @@ -177,7 +180,7 @@ func (t *JWTTransport) refreshJwtToken() error { log.Debugf("scenarios list updated for '%s'", *t.MachineID) } - var auth = models.WatcherAuthRequest{ + auth := models.WatcherAuthRequest{ MachineID: t.MachineID, Password: t.Password, Scenarios: t.Scenarios, @@ -264,13 +267,14 @@ func (t *JWTTransport) refreshJwtToken() error { // RoundTrip implements the RoundTripper interface. func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // in a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI + // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI // we use a mutex to avoid this - //We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request) + // We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request) t.refreshTokenMutex.Lock() if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) { if err := t.refreshJwtToken(); err != nil { t.refreshTokenMutex.Unlock() + return nil, err } } @@ -296,8 +300,9 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { } if err != nil { - /*we had an error (network error for example, or 401 because token is refused), reset the token ?*/ + // we had an error (network error for example, or 401 because token is refused), reset the token ? t.Token = "" + return resp, fmt.Errorf("performing jwt auth: %w", err) } @@ -355,6 +360,7 @@ func cloneRequest(r *http.Request) *http.Request { *r2 = *r // deep copy of the Header r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { r2.Header[k] = append([]string(nil), s...) } diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index 75bc52881..b183a8c79 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -74,6 +74,7 @@ func NewClient(config *Config) (*ApiClient, error) { VersionPrefix: config.VersionPrefix, UpdateScenario: config.UpdateScenario, } + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig.RootCAs = CaCertPool @@ -180,8 +181,7 @@ func (e *ErrorResponse) Error() string { } func newResponse(r *http.Response) *Response { - response := &Response{Response: r} - return response + return &Response{Response: r} } func CheckResponse(r *http.Response) error { @@ -192,7 +192,7 @@ func CheckResponse(r *http.Response) error { errorResponse := &ErrorResponse{} data, err := io.ReadAll(r.Body) - if err == nil && data != nil { + if err == nil && len(data)>0 { err := json.Unmarshal(data, errorResponse) if err != nil { return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err) diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index 89e6eff92..a3f02c0ef 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -183,7 +183,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl req = req.WithContext(ctx) log.Debugf("[URL] %s %s", req.Method, req.URL) - // we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc + // we don't use client_http Do method because we need the reader and is not provided. + // We would be forced to use Pipe and goroutine, etc resp, err := client.Do(req) if resp != nil && resp.Body != nil { defer resp.Body.Close() @@ -216,6 +217,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl if resp.StatusCode != http.StatusOK { log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL) + return nil, false, nil } diff --git a/pkg/apiclient/heartbeat.go b/pkg/apiclient/heartbeat.go index bf61b8d2e..77e0ecc2e 100644 --- a/pkg/apiclient/heartbeat.go +++ b/pkg/apiclient/heartbeat.go @@ -38,11 +38,13 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) { select { case <-hbTimer.C: log.Debug("heartbeat: sending heartbeat") + ok, resp, err := h.Ping(ctx) if err != nil { log.Errorf("heartbeat error : %s", err) continue } + resp.Response.Body.Close() if resp.Response.StatusCode != http.StatusOK { log.Errorf("heartbeat unexpected return code : %d", resp.Response.StatusCode) diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 11d0c3eaa..638ac2c65 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -307,6 +307,7 @@ func (s *APIServer) Run(apiReady chan bool) error { log.Errorf("capi push: %s", err) return err } + return nil }) @@ -315,6 +316,7 @@ func (s *APIServer) Run(apiReady chan bool) error { log.Errorf("capi pull: %s", err) return err } + return nil }) @@ -328,6 +330,7 @@ func (s *APIServer) Run(apiReady chan bool) error { log.Errorf("papi pull: %s", err) return err } + return nil }) @@ -336,6 +339,7 @@ func (s *APIServer) Run(apiReady chan bool) error { log.Errorf("capi decisions sync: %s", err) return err } + return nil }) } else { diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index 5794b40d3..bab196512 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -55,6 +55,7 @@ func serveHealth() http.HandlerFunc { // no caching required health.WithDisabledCache(), ) + return health.NewHandler(checker) } @@ -76,6 +77,7 @@ func (c *Controller) NewV1() error { if err != nil { return err } + c.Router.GET("/health", gin.WrapF(serveHealth())) c.Router.Use(v1.PrometheusMiddleware()) c.Router.HandleMethodNotAllowed = true @@ -104,7 +106,6 @@ func (c *Controller) NewV1() error { jwtAuth.DELETE("/decisions", c.HandlerV1.DeleteDecisions) jwtAuth.DELETE("/decisions/:decision_id", c.HandlerV1.DeleteDecisionById) jwtAuth.GET("/heartbeat", c.HandlerV1.HeartBeat) - } apiKeyAuth := groupV1.Group("") diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 10841ce45..424c20af6 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -22,7 +22,6 @@ import ( ) func FormatOneAlert(alert *ent.Alert) *models.Alert { - var outputAlert models.Alert startAt := alert.StartedAt.String() StopAt := alert.StoppedAt.String() @@ -31,7 +30,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { machineID = alert.Edges.Owner.MachineId } - outputAlert = models.Alert{ + outputAlert := models.Alert{ ID: int64(alert.ID), MachineID: machineID, CreatedAt: alert.CreatedAt.Format(time.RFC3339), @@ -58,23 +57,27 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { Longitude: alert.SourceLongitude, }, } + for _, eventItem := range alert.Edges.Events { var Metas models.Meta timestamp := eventItem.Time.String() if err := json.Unmarshal([]byte(eventItem.Serialized), &Metas); err != nil { log.Errorf("unable to unmarshall events meta '%s' : %s", eventItem.Serialized, err) } + outputAlert.Events = append(outputAlert.Events, &models.Event{ Timestamp: ×tamp, Meta: Metas, }) } + for _, metaItem := range alert.Edges.Metas { outputAlert.Meta = append(outputAlert.Meta, &models.MetaItems0{ Key: metaItem.Key, Value: metaItem.Value, }) } + for _, decisionItem := range alert.Edges.Decisions { duration := decisionItem.Until.Sub(time.Now().UTC()).String() outputAlert.Decisions = append(outputAlert.Decisions, &models.Decision{ @@ -88,6 +91,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { ID: int64(decisionItem.ID), }) } + return &outputAlert } @@ -97,6 +101,7 @@ func FormatAlerts(result []*ent.Alert) models.AddAlertsRequest { for _, alertItem := range result { data = append(data, FormatOneAlert(alertItem)) } + return data } @@ -107,6 +112,7 @@ func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uin select { case c.PluginChannel <- csplugin.ProfileAlert{ProfileID: profileID, Alert: alert}: log.Debugf("alert sent to Plugin channel") + break RETRY default: log.Warningf("Cannot send alert to Plugin channel (try: %d)", try) @@ -133,7 +139,6 @@ func normalizeScope(scope string) string { // CreateAlert writes the alerts received in the body to the database func (c *Controller) CreateAlert(gctx *gin.Context) { - var input models.AddAlertsRequest claims := jwt.ExtractClaims(gctx) @@ -144,13 +149,16 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return } + if err := input.Validate(strfmt.Default); err != nil { c.HandleDBErrors(gctx, err) return } + stopFlush := false + for _, alert := range input { - //normalize scope for alert.Source and decisions + // normalize scope for alert.Source and decisions if alert.Source.Scope != nil { *alert.Source.Scope = normalizeScope(*alert.Source.Scope) } @@ -161,15 +169,16 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { } alert.MachineID = machineID - //generate uuid here for alert + // generate uuid here for alert alert.UUID = uuid.NewString() - //if coming from cscli, alert already has decisions + // if coming from cscli, alert already has decisions if len(alert.Decisions) != 0 { //alert already has a decision (cscli decisions add etc.), generate uuid here for _, decision := range alert.Decisions { decision.UUID = uuid.NewString() } + for pIdx, profile := range c.Profiles { _, matched, err := profile.EvaluateProfile(alert) if err != nil {