diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index e7d106d72..ad183e4ba 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -9,7 +9,6 @@ import ( "strings" "time" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" "github.com/google/uuid" @@ -143,9 +142,7 @@ func normalizeScope(scope string) string { func (c *Controller) CreateAlert(gctx *gin.Context) { var input models.AddAlertsRequest - claims := jwt.ExtractClaims(gctx) - // TBD: use defined rather than hardcoded key to find back owner - machineID := claims["id"].(string) + machineID, _ := getMachineIDFromContext(gctx) if err := gctx.ShouldBindJSON(&input); err != nil { gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) diff --git a/pkg/apiserver/controllers/v1/heartbeat.go b/pkg/apiserver/controllers/v1/heartbeat.go index b19b450f0..e1231eaa9 100644 --- a/pkg/apiserver/controllers/v1/heartbeat.go +++ b/pkg/apiserver/controllers/v1/heartbeat.go @@ -3,14 +3,11 @@ package v1 import ( "net/http" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" ) func (c *Controller) HeartBeat(gctx *gin.Context) { - claims := jwt.ExtractClaims(gctx) - // TBD: use defined rather than hardcoded key to find back owner - machineID := claims["id"].(string) + machineID, _ := getMachineIDFromContext(gctx) if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { c.HandleDBErrors(gctx, err) diff --git a/pkg/apiserver/controllers/v1/metrics.go b/pkg/apiserver/controllers/v1/metrics.go index 13ccf9ac9..ddb38512a 100644 --- a/pkg/apiserver/controllers/v1/metrics.go +++ b/pkg/apiserver/controllers/v1/metrics.go @@ -3,7 +3,6 @@ package v1 import ( "time" - jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus" ) @@ -66,32 +65,29 @@ var LapiResponseTime = prometheus.NewHistogramVec( []string{"endpoint", "method"}) func PrometheusBouncersHasEmptyDecision(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiNilDecisions.With(prometheus.Labels{ - "bouncer": name.(string)}).Inc() + "bouncer": bouncer.Name}).Inc() } } func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiNonNilDecisions.With(prometheus.Labels{ - "bouncer": name.(string)}).Inc() + "bouncer": bouncer.Name}).Inc() } } func PrometheusMachinesMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - claims := jwt.ExtractClaims(c) - if claims != nil { - if rawID, ok := claims["id"]; ok { - machineID := rawID.(string) - LapiMachineHits.With(prometheus.Labels{ - "machine": machineID, - "route": c.Request.URL.Path, - "method": c.Request.Method}).Inc() - } + machineID, _ := getMachineIDFromContext(c) + if machineID != "" { + LapiMachineHits.With(prometheus.Labels{ + "machine": machineID, + "route": c.Request.URL.Path, + "method": c.Request.Method}).Inc() } c.Next() @@ -100,10 +96,10 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc { func PrometheusBouncersMiddleware() gin.HandlerFunc { return func(c *gin.Context) { - name, ok := c.Get("BOUNCER_NAME") - if ok { + bouncer, _ := getBouncerFromContext(c) + if bouncer != nil { LapiBouncerHits.With(prometheus.Labels{ - "bouncer": name.(string), + "bouncer": bouncer.Name, "route": c.Request.URL.Path, "method": c.Request.Method}).Inc() } diff --git a/pkg/apiserver/controllers/v1/utils.go b/pkg/apiserver/controllers/v1/utils.go index 6afd00513..6f14dd920 100644 --- a/pkg/apiserver/controllers/v1/utils.go +++ b/pkg/apiserver/controllers/v1/utils.go @@ -1,30 +1,50 @@ package v1 import ( - "fmt" + "errors" "net/http" + jwt "github.com/appleboy/gin-jwt/v2" "github.com/gin-gonic/gin" + middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/database/ent" ) -const bouncerContextKey = "bouncer_info" - func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { - bouncerInterface, exist := ctx.Get(bouncerContextKey) + bouncerInterface, exist := ctx.Get(middlewares.BouncerContextKey) if !exist { - return nil, fmt.Errorf("bouncer not found") + return nil, errors.New("bouncer not found") } bouncerInfo, ok := bouncerInterface.(*ent.Bouncer) if !ok { - return nil, fmt.Errorf("bouncer not found") + return nil, errors.New("bouncer not found") } return bouncerInfo, nil } +func getMachineIDFromContext(ctx *gin.Context) (string, error) { + claims := jwt.ExtractClaims(ctx) + if claims == nil { + return "", errors.New("failed to extract claims") + } + + rawID, ok := claims[middlewares.MachineIDKey] + if !ok { + return "", errors.New("MachineID not found in claims") + } + + id, ok := rawID.(string) + if !ok { + // should never happen + return "", errors.New("failed to cast machineID to string") + } + + return id, nil +} + func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc { return func(gctx *gin.Context) { incomingIP := gctx.ClientIP() diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 41ee15b44..4e273371b 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -18,9 +18,9 @@ import ( const ( APIKeyHeader = "X-Api-Key" - bouncerContextKey = "bouncer_info" - // max allowed by bcrypt 72 = 54 bytes in base64 + BouncerContextKey = "bouncer_info" dummyAPIKeySize = 54 + // max allowed by bcrypt 72 = 54 bytes in base64 ) type APIKey struct { @@ -159,11 +159,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { "name": bouncer.Name, }) - // maybe we want to store the whole bouncer object in the context instead, this would avoid another db query - // in StreamDecision - c.Set("BOUNCER_NAME", bouncer.Name) - c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey) - if bouncer.IPAddress == "" { if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) @@ -203,7 +198,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } } - c.Set(bouncerContextKey, bouncer) + c.Set(BouncerContextKey, bouncer) c.Next() } } diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index ed4ad107b..6fe053713 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -22,7 +22,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -var identityKey = "id" +const MachineIDKey = "id" type JWT struct { Middleware *jwt.GinJWTMiddleware @@ -33,7 +33,7 @@ type JWT struct { func PayloadFunc(data interface{}) jwt.MapClaims { if value, ok := data.(*models.WatcherAuthRequest); ok { return jwt.MapClaims{ - identityKey: &value.MachineID, + MachineIDKey: &value.MachineID, } } @@ -42,7 +42,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims { func IdentityHandler(c *gin.Context) interface{} { claims := jwt.ExtractClaims(c) - machineID := claims[identityKey].(string) + machineID := claims[MachineIDKey].(string) return &models.WatcherAuthRequest{ MachineID: &machineID, @@ -307,7 +307,7 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { Key: secret, Timeout: time.Hour, MaxRefresh: time.Hour, - IdentityKey: identityKey, + IdentityKey: MachineIDKey, PayloadFunc: PayloadFunc, IdentityHandler: IdentityHandler, Authenticator: jwtMiddleware.Authenticator,