diff --git a/pkg/apiclient/alerts_service.go b/pkg/apiclient/alerts_service.go index 1d0a4ebd1..eb41452ea 100644 --- a/pkg/apiclient/alerts_service.go +++ b/pkg/apiclient/alerts_service.go @@ -49,15 +49,15 @@ type AlertsDeleteOpts struct { } func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) { - var addedIds models.AddAlertsResponse - u := fmt.Sprintf("%s/alerts", s.client.URLPrefix) - req, err := s.client.NewRequest(http.MethodPost, u, &alerts) + req, err := s.client.NewRequest(http.MethodPost, u, &alerts) if err != nil { return nil, nil, err } + var addedIds models.AddAlertsResponse + resp, err := s.client.Do(ctx, req, &addedIds) if err != nil { return nil, resp, err @@ -68,22 +68,16 @@ func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) // to demo query arguments func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.GetAlertsResponse, *Response, error) { - var ( - alerts models.GetAlertsResponse - URI string - ) - u := fmt.Sprintf("%s/alerts", s.client.URLPrefix) - params, err := qs.Values(opts) + params, err := qs.Values(opts) if err != nil { return nil, nil, fmt.Errorf("building query: %w", err) } + URI := u if len(params) > 0 { - URI = fmt.Sprintf("%s?%s", u, params.Encode()) - } else { - URI = u + URI = fmt.Sprintf("%s?%s", URI, params.Encode()) } req, err := s.client.NewRequest(http.MethodGet, URI, nil) @@ -91,6 +85,8 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models. return nil, nil, fmt.Errorf("building request: %w", err) } + alerts := models.GetAlertsResponse{} + resp, err := s.client.Do(ctx, req, &alerts) if err != nil { return nil, resp, fmt.Errorf("performing request: %w", err) @@ -101,8 +97,6 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models. // to demo query arguments func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*models.DeleteAlertsResponse, *Response, error) { - var alerts models.DeleteAlertsResponse - params, err := qs.Values(opts) if err != nil { return nil, nil, err @@ -115,6 +109,8 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod return nil, nil, err } + alerts := models.DeleteAlertsResponse{} + resp, err := s.client.Do(ctx, req, &alerts) if err != nil { return nil, resp, err @@ -124,8 +120,6 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod } func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.DeleteAlertsResponse, *Response, error) { - var alerts models.DeleteAlertsResponse - u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alertID) req, err := s.client.NewRequest(http.MethodDelete, u, nil) @@ -133,6 +127,8 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models. return nil, nil, err } + alerts := models.DeleteAlertsResponse{} + resp, err := s.client.Do(ctx, req, &alerts) if err != nil { return nil, resp, err @@ -142,8 +138,6 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models. } func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) { - var alert models.Alert - u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID) req, err := s.client.NewRequest(http.MethodGet, u, nil) @@ -151,6 +145,8 @@ func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert return nil, nil, err } + alert := models.Alert{} + resp, err := s.client.Do(ctx, req, &alert) if err != nil { return nil, nil, err diff --git a/pkg/apiclient/auth.go b/pkg/apiclient/auth.go index 163e96718..02a2f5ada 100644 --- a/pkg/apiclient/auth.go +++ b/pkg/apiclient/auth.go @@ -125,7 +125,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) select { case <-req.Context().Done(): - return resp, req.Context().Err() + return nil, req.Context().Err() case <-time.After(time.Duration(backoff) * time.Second): } } @@ -135,8 +135,8 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) } clonedReq := cloneRequest(req) - resp, err = r.next.RoundTrip(clonedReq) + resp, err = r.next.RoundTrip(clonedReq) if err != nil { if left := maxAttempts - i - 1; left > 0 { log.Errorf("error while performing request: %s; %d retries left", err, left) @@ -171,10 +171,11 @@ type JWTTransport struct { func (t *JWTTransport) refreshJwtToken() error { var err error + if t.UpdateScenario != nil { t.Scenarios, err = t.UpdateScenario() if err != nil { - return fmt.Errorf("can't update scenario list: %s", err) + return fmt.Errorf("can't update scenario list: %w", err) } log.Debugf("scenarios list updated for '%s'", *t.MachineID) @@ -186,8 +187,6 @@ func (t *JWTTransport) refreshJwtToken() error { Scenarios: t.Scenarios, } - var response models.WatcherAuthResponse - /* we don't use the main client, so let's build the body */ @@ -250,6 +249,8 @@ func (t *JWTTransport) refreshJwtToken() error { } } + var response models.WatcherAuthResponse + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { return fmt.Errorf("unable to decode response: %w", err) } @@ -300,7 +301,7 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { } if err != nil { - // we had an error (network error for example, or 401 because token is refused), reset the token ? + // we had an error (network error for example, or 401 because token is refused), reset the token? t.Token = "" return resp, fmt.Errorf("performing jwt auth: %w", err) @@ -324,14 +325,13 @@ func (t *JWTTransport) ResetToken() { t.refreshTokenMutex.Unlock() } +// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when the infrastructure is overloaded. func (t *JWTTransport) transport() http.RoundTripper { - var transport http.RoundTripper - if t.Transport != nil { - transport = t.Transport - } else { + transport := t.Transport + if transport == nil { transport = http.DefaultTransport } - // a round tripper that retries once when the status is unauthorized and 5 times when infrastructure is overloaded + return &retryRoundTripper{ next: &retryRoundTripper{ next: transport, diff --git a/pkg/apiclient/client_http.go b/pkg/apiclient/client_http.go index 5222ad770..0240618f5 100644 --- a/pkg/apiclient/client_http.go +++ b/pkg/apiclient/client_http.go @@ -94,7 +94,7 @@ func (c *ApiClient) Do(ctx context.Context, req *http.Request, v interface{}) (* if log.GetLevel() >= log.DebugLevel { for k, v := range resp.Header { - log.Debugf("[headers] %s : %s", k, v) + log.Debugf("[headers] %s: %s", k, v) } dump, err := httputil.DumpResponse(resp, true) diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index a3f02c0ef..388a870f9 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -3,11 +3,11 @@ package apiclient import ( "bufio" "context" + "errors" "fmt" "net/http" qs "github.com/google/go-querystring/query" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -61,8 +61,6 @@ type DecisionsDeleteOpts struct { // to demo query arguments func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*models.GetDecisionsResponse, *Response, error) { - var decisions models.GetDecisionsResponse - params, err := qs.Values(opts) if err != nil { return nil, nil, err @@ -75,6 +73,8 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m return nil, nil, err } + var decisions models.GetDecisionsResponse + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err @@ -84,13 +84,13 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m } func (s *DecisionsService) FetchV2Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) { - var decisions models.DecisionsStreamResponse - req, err := s.client.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, nil, err } + var decisions models.DecisionsStreamResponse + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err @@ -100,7 +100,7 @@ func (s *DecisionsService) FetchV2Decisions(ctx context.Context, url string) (*m } func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi.GetDecisionsStreamResponseNewItem) []*models.Decision { - var decisions []*models.Decision + decisions := make([]*models.Decision, 0) for _, decisionsGroup := range decisionsGroups { partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions)) @@ -122,11 +122,6 @@ func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi. } func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) { - var ( - decisions modelscapi.GetDecisionsStreamResponse - v2Decisions models.DecisionsStreamResponse - ) - scenarioDeleted := "deleted" durationDeleted := "1h" @@ -135,11 +130,14 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m return nil, nil, err } + decisions := modelscapi.GetDecisionsStreamResponse{} + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err } + v2Decisions := models.DecisionsStreamResponse{} v2Decisions.New = s.GetDecisionsFromGroups(decisions.New) for _, decisionsGroup := range decisions.Deleted { @@ -183,6 +181,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl req = req.WithContext(ctx) log.Debugf("[URL] %s %s", req.Method, req.URL) + // we don't use client_http Do method because we need the reader and is not provided. // We would be forced to use Pipe and goroutine, etc resp, err := client.Do(req) @@ -247,11 +246,11 @@ func (s *DecisionsService) GetStream(ctx context.Context, opts DecisionsStreamOp return nil, nil, err } - if s.client.URLPrefix == "v3" { - return s.FetchV3Decisions(ctx, u) - } else { + if s.client.URLPrefix != "v3" { return s.FetchV2Decisions(ctx, u) } + + return s.FetchV3Decisions(ctx, u) } func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStreamOpts) (*modelscapi.GetDecisionsStreamResponse, *Response, error) { @@ -260,13 +259,13 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream return nil, nil, err } - var decisions modelscapi.GetDecisionsStreamResponse - req, err := s.client.NewRequest(http.MethodGet, u, nil) if err != nil { return nil, nil, err } + decisions := modelscapi.GetDecisionsStreamResponse{} + resp, err := s.client.Do(ctx, req, &decisions) if err != nil { return nil, resp, err @@ -292,8 +291,6 @@ func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) { } func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) (*models.DeleteDecisionResponse, *Response, error) { - var deleteDecisionResponse models.DeleteDecisionResponse - params, err := qs.Values(opts) if err != nil { return nil, nil, err @@ -306,6 +303,8 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) return nil, nil, err } + deleteDecisionResponse := models.DeleteDecisionResponse{} + resp, err := s.client.Do(ctx, req, &deleteDecisionResponse) if err != nil { return nil, resp, err @@ -315,8 +314,6 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) } func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*models.DeleteDecisionResponse, *Response, error) { - var deleteDecisionResponse models.DeleteDecisionResponse - u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decisionID) req, err := s.client.NewRequest(http.MethodDelete, u, nil) @@ -324,6 +321,8 @@ func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*m return nil, nil, err } + deleteDecisionResponse := models.DeleteDecisionResponse{} + resp, err := s.client.Do(ctx, req, &deleteDecisionResponse) if err != nil { return nil, resp, err diff --git a/pkg/apiclient/decisions_sync_service.go b/pkg/apiclient/decisions_sync_service.go index 1aee9b6ca..25e33a8e2 100644 --- a/pkg/apiclient/decisions_sync_service.go +++ b/pkg/apiclient/decisions_sync_service.go @@ -14,8 +14,6 @@ type DecisionDeleteService service // DecisionDeleteService purposely reuses AddSignalsRequestItemDecisions model func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) { - var response interface{} - u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix) req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions) @@ -23,15 +21,17 @@ func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *model return nil, nil, fmt.Errorf("while building request: %w", err) } + var response interface{} + resp, err := d.client.Do(ctx, req, &response) if err != nil { return nil, resp, fmt.Errorf("while performing request: %w", err) } if resp.Response.StatusCode != http.StatusOK { - log.Warnf("Decisions delete response : http %s", resp.Response.Status) + log.Warnf("Decisions delete response: http %s", resp.Response.Status) } else { - log.Debugf("Decisions delete response : http %s", resp.Response.Status) + log.Debugf("Decisions delete response: http %s", resp.Response.Status) } return &response, resp, nil diff --git a/pkg/apiclient/heartbeat.go b/pkg/apiclient/heartbeat.go index 77e0ecc2e..df3afc52f 100644 --- a/pkg/apiclient/heartbeat.go +++ b/pkg/apiclient/heartbeat.go @@ -41,13 +41,13 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) { ok, resp, err := h.Ping(ctx) if err != nil { - log.Errorf("heartbeat error : %s", err) + log.Errorf("heartbeat error: %s", err) continue } resp.Response.Body.Close() if resp.Response.StatusCode != http.StatusOK { - log.Errorf("heartbeat unexpected return code : %d", resp.Response.StatusCode) + log.Errorf("heartbeat unexpected return code: %d", resp.Response.StatusCode) continue } if !ok { diff --git a/pkg/apiclient/metrics.go b/pkg/apiclient/metrics.go index a82273007..7f8d095a2 100644 --- a/pkg/apiclient/metrics.go +++ b/pkg/apiclient/metrics.go @@ -11,8 +11,6 @@ import ( type MetricsService service func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (interface{}, *Response, error) { - var response interface{} - u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix) req, err := s.client.NewRequest(http.MethodPost, u, &metrics) @@ -20,6 +18,8 @@ func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (inte return nil, nil, err } + var response interface{} + resp, err := s.client.Do(ctx, req, &response) if err != nil { return nil, resp, err diff --git a/pkg/apiclient/signal.go b/pkg/apiclient/signal.go index 94c02f080..613ce70bb 100644 --- a/pkg/apiclient/signal.go +++ b/pkg/apiclient/signal.go @@ -13,8 +13,6 @@ import ( type SignalService service func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsRequest) (interface{}, *Response, error) { - var response interface{} - u := fmt.Sprintf("%s/signals", s.client.URLPrefix) req, err := s.client.NewRequest(http.MethodPost, u, &signals) @@ -22,6 +20,8 @@ func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsReque return nil, nil, fmt.Errorf("while building request: %w", err) } + var response interface{} + resp, err := s.client.Do(ctx, req, &response) if err != nil { return nil, resp, fmt.Errorf("while performing request: %w", err) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index dcf12929a..961a9b5ac 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -2,21 +2,21 @@ package apiserver import ( "context" + "errors" "fmt" "math/rand" "net" "net/http" "net/url" + "slices" "strconv" "strings" "sync" "time" "github.com/go-openapi/strfmt" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" - "slices" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/trace" @@ -652,12 +652,13 @@ func (a *apic) PullTop(forcePull bool) error { addCounters, deleteCounters := makeAddAndDeleteCounters() // process deleted decisions - if nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters); err != nil { + nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters) + if err != nil { return err - } else { - log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted) } + log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted) + if len(data.New) == 0 { log.Infof("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)") return nil diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 638ac2c65..58caeb068 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -2,6 +2,7 @@ package apiserver import ( "context" + "errors" "fmt" "io" "net" @@ -13,7 +14,6 @@ import ( "github.com/gin-gonic/gin" "github.com/go-co-op/gocron" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" "gopkg.in/tomb.v2" @@ -382,7 +382,9 @@ func (s *APIServer) listenAndServeURL(apiReady chan bool) { if s.TLS.KeyFilePath == "" { serverError <- errors.New("missing TLS key file") return - } else if s.TLS.CertFilePath == "" { + } + + if s.TLS.CertFilePath == "" { serverError <- errors.New("missing TLS cert file") return } diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 424c20af6..e7d106d72 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -59,8 +59,10 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { } for _, eventItem := range alert.Edges.Events { - var Metas models.Meta timestamp := eventItem.Time.String() + + var Metas models.Meta + if err := json.Unmarshal([]byte(eventItem.Serialized), &Metas); err != nil { log.Errorf("unable to unmarshall events meta '%s' : %s", eventItem.Serialized, err) } @@ -162,6 +164,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { if alert.Source.Scope != nil { *alert.Source.Scope = normalizeScope(*alert.Source.Scope) } + for _, decision := range alert.Decisions { if decision.Scope != nil { *decision.Scope = normalizeScope(*decision.Scope) @@ -183,30 +186,38 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { _, matched, err := profile.EvaluateProfile(alert) if err != nil { profile.Logger.Warningf("error while evaluating profile %s : %v", profile.Cfg.Name, err) + continue } + if !matched { continue } + c.sendAlertToPluginChannel(alert, uint(pIdx)) + if profile.Cfg.OnSuccess == "break" { break } } + decision := alert.Decisions[0] if decision.Origin != nil && *decision.Origin == types.CscliImportOrigin { stopFlush = true } + continue } for pIdx, profile := range c.Profiles { profileDecisions, matched, err := profile.EvaluateProfile(alert) forceBreak := false + if err != nil { switch profile.Cfg.OnError { case "apply": profile.Logger.Warningf("applying profile %s despite error: %s", profile.Cfg.Name, err) + matched = true case "continue": profile.Logger.Warningf("skipping %s profile due to error: %s", profile.Cfg.Name, err) @@ -219,18 +230,23 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { return } } + if !matched { continue } + for _, decision := range profileDecisions { decision.UUID = uuid.NewString() } - //generate uuid here for alert + + // generate uuid here for alert if len(alert.Decisions) == 0 { // non manual decision alert.Decisions = append(alert.Decisions, profileDecisions...) } + profileAlert := *alert c.sendAlertToPluginChannel(&profileAlert, uint(pIdx)) + if profile.Cfg.OnSuccess == "break" || forceBreak { break } @@ -275,6 +291,7 @@ func (c *Controller) FindAlerts(gctx *gin.Context) { gctx.String(http.StatusOK, "") return } + gctx.JSON(http.StatusOK, data) } @@ -282,21 +299,25 @@ func (c *Controller) FindAlerts(gctx *gin.Context) { func (c *Controller) FindAlertByID(gctx *gin.Context) { alertIDStr := gctx.Param("alert_id") alertID, err := strconv.Atoi(alertIDStr) + if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) return } + result, err := c.DBClient.GetAlertByID(alertID) if err != nil { c.HandleDBErrors(gctx, err) return } + data := FormatOneAlert(result) if gctx.Request.Method == http.MethodHead { gctx.String(http.StatusOK, "") return } + gctx.JSON(http.StatusOK, data) } @@ -316,15 +337,14 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) { gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"}) return } + err = c.DBClient.DeleteAlertByID(decisionID) if err != nil { c.HandleDBErrors(gctx, err) return } - deleteAlertResp := models.DeleteAlertsResponse{ - NbDeleted: "1", - } + deleteAlertResp := models.DeleteAlertsResponse{NbDeleted: "1"} gctx.JSON(http.StatusOK, deleteAlertResp) } @@ -336,15 +356,17 @@ func (c *Controller) DeleteAlerts(gctx *gin.Context) { gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)}) return } - var err error + nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) return } + deleteAlertsResp := models.DeleteAlertsResponse{ NbDeleted: strconv.Itoa(nbDeleted), } + gctx.JSON(http.StatusOK, deleteAlertsResp) } @@ -355,5 +377,6 @@ func networksContainIP(networks []net.IPNet, ip string) bool { return true } } + return false } diff --git a/pkg/apiserver/controllers/v1/controller.go b/pkg/apiserver/controllers/v1/controller.go index 60da83d7d..ad76ad766 100644 --- a/pkg/apiserver/controllers/v1/controller.go +++ b/pkg/apiserver/controllers/v1/controller.go @@ -61,8 +61,10 @@ func New(cfg *ControllerV1Config) (*Controller, error) { TrustedIPs: cfg.TrustedIPs, } v1.Middlewares, err = middlewares.NewMiddlewares(cfg.DbClient) + if err != nil { return v1, err } + return v1, nil } diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 534870484..9acfc1f2e 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -33,6 +33,7 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision { } results = append(results, &decision) } + return results } @@ -44,12 +45,14 @@ func (c *Controller) GetDecision(gctx *gin.Context) { bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) + return } data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) + return } @@ -64,6 +67,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) { if gctx.Request.Method == http.MethodHead { gctx.String(http.StatusOK, "") + return } @@ -77,20 +81,22 @@ func (c *Controller) GetDecision(gctx *gin.Context) { } func (c *Controller) DeleteDecisionById(gctx *gin.Context) { - var err error - decisionIDStr := gctx.Param("decision_id") + decisionID, err := strconv.Atoi(decisionIDStr) if err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": "decision_id must be valid integer"}) + return } nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionByID(decisionID) if err != nil { c.HandleDBErrors(gctx, err) + return } - //transform deleted decisions to be sendable to capi + + // transform deleted decisions to be sendable to capi deletedDecisions := FormatDecisions(deletedFromDB) if c.DecisionDeleteChan != nil { @@ -105,13 +111,14 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) { } func (c *Controller) DeleteDecisions(gctx *gin.Context) { - var err error nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionsWithFilter(gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) + return } - //transform deleted decisions to be sendable to capi + + // transform deleted decisions to be sendable to capi deletedDecisions := FormatDecisions(deletedFromDB) if c.DecisionDeleteChan != nil { @@ -121,6 +128,7 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) { deleteDecisionResp := models.DeleteDecisionResponse{ NbDeleted: nbDeleted, } + gctx.JSON(http.StatusOK, deleteDecisionResp) } @@ -147,6 +155,7 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) + if needComma { //respBuffer.Write([]byte(",")) gctx.Writer.Write([]byte(",")) @@ -158,6 +167,7 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun _, err := gctx.Writer.Write(decisionJSON) if err != nil { gctx.Writer.Flush() + return err } //respBuffer.Reset() @@ -166,9 +176,11 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) if len(data) < limit { gctx.Writer.Flush() + break } } + return nil } @@ -195,6 +207,7 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul results := FormatDecisions(data) for _, decision := range results { decisionJSON, _ := json.Marshal(decision) + if needComma { //respBuffer.Write([]byte(",")) gctx.Writer.Write([]byte(",")) @@ -206,6 +219,7 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul _, err := gctx.Writer.Write(decisionJSON) if err != nil { gctx.Writer.Flush() + return err } //respBuffer.Reset() @@ -214,9 +228,11 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId) if len(data) < limit { gctx.Writer.Flush() + break } } + return nil } @@ -230,14 +246,13 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B // if the blocker just started, return all decisions if val, ok := gctx.Request.URL.Query()["startup"]; ok && val[0] == "true" { - //Active decisions - + // Active decisions err := writeStartupDecisions(gctx, filters, c.DBClient.QueryAllDecisionsWithFilters) - if err != nil { log.Errorf("failed sending new decisions for startup: %v", err) gctx.Writer.Write([]byte(`], "deleted": []}`)) gctx.Writer.Flush() + return err } @@ -248,6 +263,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B log.Errorf("failed sending expired decisions for startup: %v", err) gctx.Writer.Write([]byte(`]}`)) gctx.Writer.Flush() + return err } @@ -259,6 +275,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B log.Errorf("failed sending new decisions for delta: %v", err) gctx.Writer.Write([]byte(`], "deleted": []}`)) gctx.Writer.Flush() + return err } @@ -270,18 +287,21 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B log.Errorf("failed sending expired decisions for delta: %v", err) gctx.Writer.Write([]byte(`]}`)) gctx.Writer.Flush() + return err } gctx.Writer.Write([]byte(`]}`)) gctx.Writer.Flush() } + return nil } func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error { var data []*ent.Decision var err error + ret := make(map[string][]*models.Decision, 0) ret["new"] = []*models.Decision{} ret["deleted"] = []*models.Decision{} @@ -292,6 +312,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if err != nil { log.Errorf("failed querying decisions: %v", err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } //data = KeepLongestDecision(data) @@ -302,11 +323,14 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } + ret["deleted"] = FormatDecisions(data) gctx.JSON(http.StatusOK, ret) + return nil } } @@ -316,6 +340,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en if err != nil { log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) + return err } //data = KeepLongestDecision(data) diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index bf6fd5781..b19b450f0 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -8,7 +8,6 @@ import ( ) func (c *Controller) HeartBeat(gctx *gin.Context) { - claims := jwt.ExtractClaims(gctx) // TBD: use defined rather than hardcoded key to find back owner machineID := claims["id"].(string) @@ -17,5 +16,6 @@ func (c *Controller) HeartBeat(gctx *gin.Context) { c.HandleDBErrors(gctx, err) return } + gctx.Status(http.StatusOK) } diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index 55f79d0c9..84a6ef258 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ b/pkg/apiserver/controllers/v1/machines.go @@ -11,19 +11,19 @@ import ( ) func (c *Controller) CreateMachine(gctx *gin.Context) { - var err error var input models.WatcherRegistrationRequest - if err = gctx.ShouldBindJSON(&input); err != nil { + + if err := gctx.ShouldBindJSON(&input); err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return } - if err = input.Validate(strfmt.Default); err != nil { + + if err := input.Validate(strfmt.Default); err != nil { c.HandleDBErrors(gctx, err) return } - _, err = c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType) - if err != nil { + if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType); err != nil { c.HandleDBErrors(gctx, err) return } diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index 676cc31ea..b1d95dd67 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -93,6 +93,7 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc { "method": c.Request.Method}).Inc() } } + c.Next() } } @@ -106,6 +107,7 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc { "route": c.Request.URL.Path, "method": c.Request.Method}).Inc() } + c.Next() } } @@ -117,6 +119,7 @@ func PrometheusMiddleware() gin.HandlerFunc { "route": c.Request.URL.Path, "method": c.Request.Method}).Inc() c.Next() + elapsed := time.Since(startTime) LapiResponseTime.With(prometheus.Labels{"method": c.Request.Method, "endpoint": c.Request.URL.Path}).Observe(elapsed.Seconds()) } diff --git a/pkg/apiserver/controllers/v1/utils.go b/pkg/apiserver/controllers/v1/utils.go index aaa17ca51..6afd00513 100644 --- a/pkg/apiserver/controllers/v1/utils.go +++ b/pkg/apiserver/controllers/v1/utils.go @@ -9,9 +9,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent" ) -var ( - bouncerContextKey = "bouncer_info" -) +const bouncerContextKey = "bouncer_info" func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { bouncerInterface, exist := ctx.Get(bouncerContextKey) diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 7e4df875c..682f6b638 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -34,7 +34,9 @@ func GenerateAPIKey(n int) (string, error) { if _, err := rand.Read(bytes); err != nil { return "", err } + encoded := base64.StdEncoding.EncodeToString(bytes) + // the '=' can cause issues on some bouncers return strings.TrimRight(encoded, "="), nil } @@ -67,6 +69,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Errorf("invalid client certificate: %s", err) return nil } + if err != nil { logger.Error(err) return nil @@ -88,7 +91,9 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Errorf("error generating mock api key: %s", err) return nil } + logger.Infof("Creating bouncer %s", bouncerName) + bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) if err != nil { logger.Errorf("while creating bouncer db entry: %s", err) @@ -103,6 +108,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Errorf("bouncer isn't allowed to auth by TLS") return nil } + return bouncer } @@ -112,6 +118,7 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { logger.Errorf("API key not found") return nil } + hashStr := HashSHA512(val[0]) bouncer, err := a.DbClient.SelectBouncer(hashStr) @@ -162,16 +169,19 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() + return } } if bouncer.IPAddress != c.ClientIP() && bouncer.IPAddress != "" { log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress) + if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() + return } } @@ -187,6 +197,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { logger.Errorf("failed to update bouncer version and type: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.Abort() + return } } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 8797761a4..ef863a7a2 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -36,14 +36,15 @@ func PayloadFunc(data interface{}) jwt.MapClaims { identityKey: &value.MachineID, } } + return jwt.MapClaims{} } func IdentityHandler(c *gin.Context) interface{} { claims := jwt.ExtractClaims(c) - machineId := claims[identityKey].(string) + machineID := claims[identityKey].(string) return &models.WatcherAuthRequest{ - MachineID: &machineId, + MachineID: &machineID, } } @@ -67,6 +68,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { log.Error(err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() + return nil, fmt.Errorf("while trying to validate client cert: %w", err) } @@ -77,6 +79,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) { } ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) + ret.clientMachine, err = j.DbClient.Ent.Machine.Query(). Where(machine.MachineId(ret.machineID)). First(j.DbClient.CTX)