refact pkg/apiserver (auth helpers) (#2856)
This commit is contained in:
parent
e34af358d7
commit
4bf640c6e8
|
@ -9,7 +9,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
jwt "github.com/appleboy/gin-jwt/v2"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-openapi/strfmt"
|
"github.com/go-openapi/strfmt"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
@ -143,9 +142,7 @@ func normalizeScope(scope string) string {
|
||||||
func (c *Controller) CreateAlert(gctx *gin.Context) {
|
func (c *Controller) CreateAlert(gctx *gin.Context) {
|
||||||
var input models.AddAlertsRequest
|
var input models.AddAlertsRequest
|
||||||
|
|
||||||
claims := jwt.ExtractClaims(gctx)
|
machineID, _ := getMachineIDFromContext(gctx)
|
||||||
// TBD: use defined rather than hardcoded key to find back owner
|
|
||||||
machineID := claims["id"].(string)
|
|
||||||
|
|
||||||
if err := gctx.ShouldBindJSON(&input); err != nil {
|
if err := gctx.ShouldBindJSON(&input); err != nil {
|
||||||
gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
|
||||||
|
|
|
@ -3,14 +3,11 @@ package v1
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
jwt "github.com/appleboy/gin-jwt/v2"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Controller) HeartBeat(gctx *gin.Context) {
|
func (c *Controller) HeartBeat(gctx *gin.Context) {
|
||||||
claims := jwt.ExtractClaims(gctx)
|
machineID, _ := getMachineIDFromContext(gctx)
|
||||||
// TBD: use defined rather than hardcoded key to find back owner
|
|
||||||
machineID := claims["id"].(string)
|
|
||||||
|
|
||||||
if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil {
|
if err := c.DBClient.UpdateMachineLastHeartBeat(machineID); err != nil {
|
||||||
c.HandleDBErrors(gctx, err)
|
c.HandleDBErrors(gctx, err)
|
||||||
|
|
|
@ -3,7 +3,6 @@ package v1
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
jwt "github.com/appleboy/gin-jwt/v2"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
)
|
)
|
||||||
|
@ -66,32 +65,29 @@ var LapiResponseTime = prometheus.NewHistogramVec(
|
||||||
[]string{"endpoint", "method"})
|
[]string{"endpoint", "method"})
|
||||||
|
|
||||||
func PrometheusBouncersHasEmptyDecision(c *gin.Context) {
|
func PrometheusBouncersHasEmptyDecision(c *gin.Context) {
|
||||||
name, ok := c.Get("BOUNCER_NAME")
|
bouncer, _ := getBouncerFromContext(c)
|
||||||
if ok {
|
if bouncer != nil {
|
||||||
LapiNilDecisions.With(prometheus.Labels{
|
LapiNilDecisions.With(prometheus.Labels{
|
||||||
"bouncer": name.(string)}).Inc()
|
"bouncer": bouncer.Name}).Inc()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) {
|
func PrometheusBouncersHasNonEmptyDecision(c *gin.Context) {
|
||||||
name, ok := c.Get("BOUNCER_NAME")
|
bouncer, _ := getBouncerFromContext(c)
|
||||||
if ok {
|
if bouncer != nil {
|
||||||
LapiNonNilDecisions.With(prometheus.Labels{
|
LapiNonNilDecisions.With(prometheus.Labels{
|
||||||
"bouncer": name.(string)}).Inc()
|
"bouncer": bouncer.Name}).Inc()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func PrometheusMachinesMiddleware() gin.HandlerFunc {
|
func PrometheusMachinesMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
claims := jwt.ExtractClaims(c)
|
machineID, _ := getMachineIDFromContext(c)
|
||||||
if claims != nil {
|
if machineID != "" {
|
||||||
if rawID, ok := claims["id"]; ok {
|
LapiMachineHits.With(prometheus.Labels{
|
||||||
machineID := rawID.(string)
|
"machine": machineID,
|
||||||
LapiMachineHits.With(prometheus.Labels{
|
"route": c.Request.URL.Path,
|
||||||
"machine": machineID,
|
"method": c.Request.Method}).Inc()
|
||||||
"route": c.Request.URL.Path,
|
|
||||||
"method": c.Request.Method}).Inc()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
|
@ -100,10 +96,10 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc {
|
||||||
|
|
||||||
func PrometheusBouncersMiddleware() gin.HandlerFunc {
|
func PrometheusBouncersMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
name, ok := c.Get("BOUNCER_NAME")
|
bouncer, _ := getBouncerFromContext(c)
|
||||||
if ok {
|
if bouncer != nil {
|
||||||
LapiBouncerHits.With(prometheus.Labels{
|
LapiBouncerHits.With(prometheus.Labels{
|
||||||
"bouncer": name.(string),
|
"bouncer": bouncer.Name,
|
||||||
"route": c.Request.URL.Path,
|
"route": c.Request.URL.Path,
|
||||||
"method": c.Request.Method}).Inc()
|
"method": c.Request.Method}).Inc()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,30 +1,50 @@
|
||||||
package v1
|
package v1
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
jwt "github.com/appleboy/gin-jwt/v2"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
|
||||||
)
|
)
|
||||||
|
|
||||||
const bouncerContextKey = "bouncer_info"
|
|
||||||
|
|
||||||
func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) {
|
func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) {
|
||||||
bouncerInterface, exist := ctx.Get(bouncerContextKey)
|
bouncerInterface, exist := ctx.Get(middlewares.BouncerContextKey)
|
||||||
if !exist {
|
if !exist {
|
||||||
return nil, fmt.Errorf("bouncer not found")
|
return nil, errors.New("bouncer not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
bouncerInfo, ok := bouncerInterface.(*ent.Bouncer)
|
bouncerInfo, ok := bouncerInterface.(*ent.Bouncer)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("bouncer not found")
|
return nil, errors.New("bouncer not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
return bouncerInfo, nil
|
return bouncerInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getMachineIDFromContext(ctx *gin.Context) (string, error) {
|
||||||
|
claims := jwt.ExtractClaims(ctx)
|
||||||
|
if claims == nil {
|
||||||
|
return "", errors.New("failed to extract claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
rawID, ok := claims[middlewares.MachineIDKey]
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("MachineID not found in claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
id, ok := rawID.(string)
|
||||||
|
if !ok {
|
||||||
|
// should never happen
|
||||||
|
return "", errors.New("failed to cast machineID to string")
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc {
|
func (c *Controller) AbortRemoteIf(option bool) gin.HandlerFunc {
|
||||||
return func(gctx *gin.Context) {
|
return func(gctx *gin.Context) {
|
||||||
incomingIP := gctx.ClientIP()
|
incomingIP := gctx.ClientIP()
|
||||||
|
|
|
@ -18,9 +18,9 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
APIKeyHeader = "X-Api-Key"
|
APIKeyHeader = "X-Api-Key"
|
||||||
bouncerContextKey = "bouncer_info"
|
BouncerContextKey = "bouncer_info"
|
||||||
// max allowed by bcrypt 72 = 54 bytes in base64
|
|
||||||
dummyAPIKeySize = 54
|
dummyAPIKeySize = 54
|
||||||
|
// max allowed by bcrypt 72 = 54 bytes in base64
|
||||||
)
|
)
|
||||||
|
|
||||||
type APIKey struct {
|
type APIKey struct {
|
||||||
|
@ -159,11 +159,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
|
||||||
"name": bouncer.Name,
|
"name": bouncer.Name,
|
||||||
})
|
})
|
||||||
|
|
||||||
// maybe we want to store the whole bouncer object in the context instead, this would avoid another db query
|
|
||||||
// in StreamDecision
|
|
||||||
c.Set("BOUNCER_NAME", bouncer.Name)
|
|
||||||
c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey)
|
|
||||||
|
|
||||||
if bouncer.IPAddress == "" {
|
if bouncer.IPAddress == "" {
|
||||||
if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil {
|
if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil {
|
||||||
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
|
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
|
||||||
|
@ -203,7 +198,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Set(bouncerContextKey, bouncer)
|
c.Set(BouncerContextKey, bouncer)
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,7 @@ import (
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
var identityKey = "id"
|
const MachineIDKey = "id"
|
||||||
|
|
||||||
type JWT struct {
|
type JWT struct {
|
||||||
Middleware *jwt.GinJWTMiddleware
|
Middleware *jwt.GinJWTMiddleware
|
||||||
|
@ -33,7 +33,7 @@ type JWT struct {
|
||||||
func PayloadFunc(data interface{}) jwt.MapClaims {
|
func PayloadFunc(data interface{}) jwt.MapClaims {
|
||||||
if value, ok := data.(*models.WatcherAuthRequest); ok {
|
if value, ok := data.(*models.WatcherAuthRequest); ok {
|
||||||
return jwt.MapClaims{
|
return jwt.MapClaims{
|
||||||
identityKey: &value.MachineID,
|
MachineIDKey: &value.MachineID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims {
|
||||||
|
|
||||||
func IdentityHandler(c *gin.Context) interface{} {
|
func IdentityHandler(c *gin.Context) interface{} {
|
||||||
claims := jwt.ExtractClaims(c)
|
claims := jwt.ExtractClaims(c)
|
||||||
machineID := claims[identityKey].(string)
|
machineID := claims[MachineIDKey].(string)
|
||||||
|
|
||||||
return &models.WatcherAuthRequest{
|
return &models.WatcherAuthRequest{
|
||||||
MachineID: &machineID,
|
MachineID: &machineID,
|
||||||
|
@ -307,7 +307,7 @@ func NewJWT(dbClient *database.Client) (*JWT, error) {
|
||||||
Key: secret,
|
Key: secret,
|
||||||
Timeout: time.Hour,
|
Timeout: time.Hour,
|
||||||
MaxRefresh: time.Hour,
|
MaxRefresh: time.Hour,
|
||||||
IdentityKey: identityKey,
|
IdentityKey: MachineIDKey,
|
||||||
PayloadFunc: PayloadFunc,
|
PayloadFunc: PayloadFunc,
|
||||||
IdentityHandler: IdentityHandler,
|
IdentityHandler: IdentityHandler,
|
||||||
Authenticator: jwtMiddleware.Authenticator,
|
Authenticator: jwtMiddleware.Authenticator,
|
||||||
|
|
Loading…
Reference in a new issue