diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index 2a1edaebd..68a7370e3 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -1,8 +1,6 @@ package v1 import ( - "crypto/sha512" - "fmt" "net/http" "strconv" "time" @@ -36,6 +34,12 @@ func (c *Controller) GetDecision(gctx *gin.Context) { var results []*models.Decision var data []*ent.Decision + bouncerInfo, err := getBouncerFromContext(gctx) + if err != nil { + gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) + return + } + data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query()) if err != nil { c.HandleDBErrors(gctx, err) @@ -59,6 +63,13 @@ func (c *Controller) GetDecision(gctx *gin.Context) { gctx.String(http.StatusOK, "") return } + + if time.Now().UTC().Sub(bouncerInfo.LastPull) >= time.Minute { + if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil { + log.Errorf("failed to update bouncer last pull: %v", err) + } + } + gctx.JSON(http.StatusOK, results) } @@ -101,25 +112,14 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) { func (c *Controller) StreamDecision(gctx *gin.Context) { var data []*ent.Decision + var err error ret := make(map[string][]*models.Decision, 0) ret["new"] = []*models.Decision{} ret["deleted"] = []*models.Decision{} + streamStartTime := time.Now().UTC() - val := gctx.Request.Header.Get(c.APIKeyHeader) - hashedKey := sha512.New() - hashedKey.Write([]byte(val)) - hashStr := fmt.Sprintf("%x", hashedKey.Sum(nil)) - bouncerInfo, err := c.DBClient.SelectBouncer(hashStr) + bouncerInfo, err := getBouncerFromContext(gctx) if err != nil { - if _, ok := err.(*ent.NotFoundError); ok { - gctx.JSON(http.StatusForbidden, gin.H{"message": err.Error()}) - } else { - gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) - } - return - } - - if bouncerInfo == nil { gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"}) return } @@ -159,7 +159,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { return } - if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil { + if err := c.DBClient.UpdateBouncerLastPull(streamStartTime, bouncerInfo.ID); err != nil { log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return @@ -201,7 +201,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { return } - if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil { + if err := c.DBClient.UpdateBouncerLastPull(streamStartTime, bouncerInfo.ID); err != nil { log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) return diff --git a/pkg/apiserver/controllers/v1/utils.go b/pkg/apiserver/controllers/v1/utils.go new file mode 100644 index 000000000..b7c413d4d --- /dev/null +++ b/pkg/apiserver/controllers/v1/utils.go @@ -0,0 +1,26 @@ +package v1 + +import ( + "fmt" + + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/gin-gonic/gin" +) + +var ( + bouncerContextKey = "bouncer_info" +) + +func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { + bouncerInterface, exist := ctx.Get(bouncerContextKey) + if !exist { + return nil, fmt.Errorf("bouncer not found") + } + + bouncerInfo, ok := bouncerInterface.(*ent.Bouncer) + if !ok { + return nil, fmt.Errorf("bouncer not found") + } + + return bouncerInfo, nil +} diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index 9612eb2e0..0447cbf9a 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -2,6 +2,7 @@ package apiserver import ( "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -313,6 +314,8 @@ func TestStreamDecisionDedup(t *testing.T) { // Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3 lapi.InsertAlertFromFile("./tests/alert_sample.json") + time.Sleep(2 * time.Second) + // Get Stream, we only get one decision (the longest one) w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) decisions, code, err := readDecisionsStreamResp(w) @@ -335,7 +338,6 @@ func TestStreamDecisionDedup(t *testing.T) { assert.Equal(t, code, 200) assert.Equal(t, len(decisions["deleted"]), 0) assert.Equal(t, len(decisions["new"]), 0) - // We delete another decision, yet don't receive it in stream, since there's another decision on same IP w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody) assert.Equal(t, 200, w.Code) diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 8e7dc087b..23267d36f 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -7,7 +7,6 @@ import ( "fmt" "net/http" "strings" - "time" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/gin-gonic/gin" @@ -15,7 +14,8 @@ import ( ) var ( - APIKeyHeader = "X-Api-Key" + APIKeyHeader = "X-Api-Key" + bouncerContextKey = "bouncer_info" ) type APIKey struct { @@ -110,11 +110,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } } - if c.Request.Method != "HEAD" && time.Now().UTC().Sub(bouncer.LastPull) >= time.Minute { - if err := a.DbClient.UpdateBouncerLastPull(time.Now().UTC(), bouncer.ID); err != nil { - log.Errorf("failed to update bouncer last pull: %v", err) - } - } + c.Set(bouncerContextKey, bouncer) c.Next() } diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index d270d5fc7..aabdcc3c3 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -763,9 +763,9 @@ func (c *Client) QueryAlertWithFilter(filter map[string][]string) ([]*ent.Alert, } if sort == "ASC" { - alerts = alerts.Order(ent.Asc(alert.FieldCreatedAt)) + alerts = alerts.Order(ent.Asc(alert.FieldCreatedAt), ent.Asc(alert.FieldID)) } else { - alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt)) + alerts = alerts.Order(ent.Desc(alert.FieldCreatedAt), ent.Desc(alert.FieldID)) } result, err := alerts.Limit(paginationSize).Offset(offset).All(c.CTX)