refact pkg/apiserver (auth helpers) (#2856)

This commit is contained in:
mmetc 2024-02-23 14:03:50 +01:00 committed by GitHub
parent e34af358d7
commit 4bf640c6e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 50 additions and 45 deletions

View file

@ -9,7 +9,6 @@ import (
"strings" "strings"
"time" "time"
jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt" "github.com/go-openapi/strfmt"
"github.com/google/uuid" "github.com/google/uuid"
@ -143,9 +142,7 @@ func normalizeScope(scope string) string {
func (c *Controller) CreateAlert(gctx *gin.Context) { func (c *Controller) CreateAlert(gctx *gin.Context) {
var input models.AddAlertsRequest var input models.AddAlertsRequest
claims := jwt.ExtractClaims(gctx) machineID, _ := getMachineIDFromContext(gctx)
// TBD: use defined rather than hardcoded key to find back owner
machineID := claims["id"].(string)
if err := gctx.ShouldBindJSON(&input); err != nil { if err := gctx.ShouldBindJSON(&input); err != nil {
gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})

View file

@ -3,14 +3,11 @@ package v1
import ( import (
"net/http" "net/http"
jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func (c *Controller) HeartBeat(gctx *gin.Context) { func (c *Controller) HeartBeat(gctx *gin.Context) {
claims := jwt.ExtractClaims(gctx) machineID, _ := getMachineIDFromContext(gctx)
// TBD: use defined rather than hardcoded key to find back owner
machineID := claims["id"].(string)
if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil { if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil {
c.HandleDBErrors(gctx, err) c.HandleDBErrors(gctx, err)

View file

@ -3,7 +3,6 @@ package v1
import ( import (
"time" "time"
jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
) )
@ -66,32 +65,29 @@ var LapiResponseTime = prometheus.NewHistogramVec(
[]string{"endpoint", "method"}) []string{"endpoint", "method"})
func PrometheusBouncersHasEmptyDecision(c *gin.Context) { func PrometheusBouncersHasEmptyDecision(c *gin.Context) {
name, ok := c.Get("BOUNCER_NAME") bouncer, _ := getBouncerFromContext(c)
if ok { if bouncer != nil {
LapiNilDecisions.With(prometheus.Labels{ LapiNilDecisions.With(prometheus.Labels{
"bouncer": name.(string)}).Inc() "bouncer": bouncer.Name}).Inc()
} }
} }
func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) { func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) {
name, ok := c.Get("BOUNCER_NAME") bouncer, _ := getBouncerFromContext(c)
if ok { if bouncer != nil {
LapiNonNilDecisions.With(prometheus.Labels{ LapiNonNilDecisions.With(prometheus.Labels{
"bouncer": name.(string)}).Inc() "bouncer": bouncer.Name}).Inc()
} }
} }
func PrometheusMachinesMiddleware() gin.HandlerFunc { func PrometheusMachinesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
claims := jwt.ExtractClaims(c) machineID, _ := getMachineIDFromContext(c)
if claims != nil { if machineID != "" {
if rawID, ok := claims["id"]; ok { LapiMachineHits.With(prometheus.Labels{
machineID := rawID.(string) "machine": machineID,
LapiMachineHits.With(prometheus.Labels{ "route": c.Request.URL.Path,
"machine": machineID, "method": c.Request.Method}).Inc()
"route": c.Request.URL.Path,
"method": c.Request.Method}).Inc()
}
} }
c.Next() c.Next()
@ -100,10 +96,10 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc {
func PrometheusBouncersMiddleware() gin.HandlerFunc { func PrometheusBouncersMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
name, ok := c.Get("BOUNCER_NAME") bouncer, _ := getBouncerFromContext(c)
if ok { if bouncer != nil {
LapiBouncerHits.With(prometheus.Labels{ LapiBouncerHits.With(prometheus.Labels{
"bouncer": name.(string), "bouncer": bouncer.Name,
"route": c.Request.URL.Path, "route": c.Request.URL.Path,
"method": c.Request.Method}).Inc() "method": c.Request.Method}).Inc()
} }

View file

@ -1,30 +1,50 @@
package v1 package v1
import ( import (
"fmt" "errors"
"net/http" "net/http"
jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
"github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent"
) )
const bouncerContextKey = "bouncer_info"
func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) { func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) {
bouncerInterface, exist := ctx.Get(bouncerContextKey) bouncerInterface, exist := ctx.Get(middlewares.BouncerContextKey)
if !exist { if !exist {
return nil, fmt.Errorf("bouncer not found") return nil, errors.New("bouncer not found")
} }
bouncerInfo, ok := bouncerInterface.(*ent.Bouncer) bouncerInfo, ok := bouncerInterface.(*ent.Bouncer)
if !ok { if !ok {
return nil, fmt.Errorf("bouncer not found") return nil, errors.New("bouncer not found")
} }
return bouncerInfo, nil return bouncerInfo, nil
} }
func getMachineIDFromContext(ctx *gin.Context) (string, error) {
claims := jwt.ExtractClaims(ctx)
if claims == nil {
return "", errors.New("failed to extract claims")
}
rawID, ok := claims[middlewares.MachineIDKey]
if !ok {
return "", errors.New("MachineID not found in claims")
}
id, ok := rawID.(string)
if !ok {
// should never happen
return "", errors.New("failed to cast machineID to string")
}
return id, nil
}
func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc { func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc {
return func(gctx *gin.Context) { return func(gctx *gin.Context) {
incomingIP := gctx.ClientIP() incomingIP := gctx.ClientIP()

View file

@ -18,9 +18,9 @@ import (
const ( const (
APIKeyHeader = "X-Api-Key" APIKeyHeader = "X-Api-Key"
bouncerContextKey = "bouncer_info" BouncerContextKey = "bouncer_info"
// max allowed by bcrypt 72 = 54 bytes in base64
dummyAPIKeySize = 54 dummyAPIKeySize = 54
// max allowed by bcrypt 72 = 54 bytes in base64
) )
type APIKey struct { type APIKey struct {
@ -159,11 +159,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
"name": bouncer.Name, "name": bouncer.Name,
}) })
// maybe we want to store the whole bouncer object in the context instead, this would avoid another db query
// in StreamDecision
c.Set("BOUNCER_NAME", bouncer.Name)
c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey)
if bouncer.IPAddress == "" { if bouncer.IPAddress == "" {
if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
@ -203,7 +198,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
} }
} }
c.Set(bouncerContextKey, bouncer) c.Set(BouncerContextKey, bouncer)
c.Next() c.Next()
} }
} }

View file

@ -22,7 +22,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
) )
var identityKey = "id" const MachineIDKey = "id"
type JWT struct { type JWT struct {
Middleware *jwt.GinJWTMiddleware Middleware *jwt.GinJWTMiddleware
@ -33,7 +33,7 @@ type JWT struct {
func PayloadFunc(data interface{}) jwt.MapClaims { func PayloadFunc(data interface{}) jwt.MapClaims {
if value, ok := data.(*models.WatcherAuthRequest); ok { if value, ok := data.(*models.WatcherAuthRequest); ok {
return jwt.MapClaims{ return jwt.MapClaims{
identityKey: &value.MachineID, MachineIDKey: &value.MachineID,
} }
} }
@ -42,7 +42,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims {
func IdentityHandler(c *gin.Context) interface{} { func IdentityHandler(c *gin.Context) interface{} {
claims := jwt.ExtractClaims(c) claims := jwt.ExtractClaims(c)
machineID := claims[identityKey].(string) machineID := claims[MachineIDKey].(string)
return &models.WatcherAuthRequest{ return &models.WatcherAuthRequest{
MachineID: &machineID, MachineID: &machineID,
@ -307,7 +307,7 @@ func NewJWT(dbClient *database.Client) (*JWT, error) {
Key: secret, Key: secret,
Timeout: time.Hour, Timeout: time.Hour,
MaxRefresh: time.Hour, MaxRefresh: time.Hour,
IdentityKey: identityKey, IdentityKey: MachineIDKey,
PayloadFunc: PayloadFunc, PayloadFunc: PayloadFunc,
IdentityHandler: IdentityHandler, IdentityHandler: IdentityHandler,
Authenticator: jwtMiddleware.Authenticator, Authenticator: jwtMiddleware.Authenticator,