From 23968e472dd8c32d4a5b949ef4f4428ca6050d30 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Mon, 4 Dec 2023 23:06:01 +0100 Subject: [PATCH] Refact bouncer auth (#2456) Co-authored-by: blotus --- cmd/crowdsec-cli/machines.go | 4 +- pkg/apiserver/apiserver.go | 12 +- pkg/apiserver/apiserver_test.go | 9 +- pkg/apiserver/middlewares/v1/api_key.go | 213 +++++++++----------- pkg/apiserver/middlewares/v1/middlewares.go | 2 +- pkg/database/bouncers.go | 6 +- pkg/database/machines.go | 11 +- 7 files changed, 112 insertions(+), 145 deletions(-) diff --git a/cmd/crowdsec-cli/machines.go b/cmd/crowdsec-cli/machines.go index b7ceed254..72e6579cc 100644 --- a/cmd/crowdsec-cli/machines.go +++ b/cmd/crowdsec-cli/machines.go @@ -30,9 +30,7 @@ import ( "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" ) -var ( - passwordLength = 64 -) +const passwordLength = 64 func generatePassword(length int) string { upper := "ABCDEFGHIJKLMNOPQRSTUVWXY" diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 070013298..cfeb13d27 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -106,13 +106,13 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { var flushScheduler *gocron.Scheduler dbClient, err := database.NewClient(config.DbConfig) if err != nil { - return &APIServer{}, fmt.Errorf("unable to init database client: %w", err) + return nil, fmt.Errorf("unable to init database client: %w", err) } if config.DbConfig.Flush != nil { flushScheduler, err = dbClient.StartFlushScheduler(config.DbConfig.Flush) if err != nil { - return &APIServer{}, err + return nil, err } } @@ -129,7 +129,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { if config.TrustedProxies != nil && config.UseForwardedForHeaders { if err := router.SetTrustedProxies(*config.TrustedProxies); err != nil { - return &APIServer{}, fmt.Errorf("while setting trusted_proxies: %w", err) + return nil, fmt.Errorf("while setting trusted_proxies: %w", err) } router.ForwardedByClientIP = true } else { @@ -215,7 +215,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { log.Printf("Loading CAPI manager") apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) if err != nil { - return &APIServer{}, err + return nil, err } log.Infof("CAPI manager configured successfully") isMachineEnrolled = isEnrolled(apiClient.apiClient) @@ -225,7 +225,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { log.Infof("Machine is enrolled in the console, Loading PAPI Client") papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel) if err != nil { - return &APIServer{}, err + return nil, err } controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel } else { @@ -241,7 +241,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { if trustedIPs, err := config.GetTrustedIPs(); err == nil { controller.TrustedIPs = trustedIPs } else { - return &APIServer{}, err + return nil, err } return &APIServer{ diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 435c9601a..6150c351b 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/version" middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" @@ -295,8 +296,8 @@ func TestWithWrongDBConfig(t *testing.T) { config.API.Server.DbConfig.Type = "test" apiServer, err := NewServer(config.API.Server) - assert.Equal(t, apiServer, &APIServer{}) - assert.Equal(t, "unable to init database client: unknown database type 'test'", err.Error()) + cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'") + assert.Nil(t, apiServer) } func TestWithWrongFlushConfig(t *testing.T) { @@ -305,8 +306,8 @@ func TestWithWrongFlushConfig(t *testing.T) { config.API.Server.DbConfig.Flush.MaxItems = &maxItems apiServer, err := NewServer(config.API.Server) - assert.Equal(t, apiServer, &APIServer{}) - assert.Equal(t, "max_items can't be zero or negative number", err.Error()) + cstest.RequireErrorContains(t, err, "max_items can't be zero or negative number") + assert.Nil(t, apiServer) } func TestUnknownPath(t *testing.T) { diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 207f35fc4..1481a0145 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -55,132 +55,110 @@ func HashSHA512(str string) string { return hashStr } +func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer { + if a.TlsAuth == nil { + logger.Error("TLS Auth is not configured but client presented a certificate") + return nil + } + + validCert, extractedCN, err := a.TlsAuth.ValidateCert(c) + if !validCert { + logger.Errorf("invalid client certificate: %s", err) + return nil + } + if err != nil { + logger.Error(err) + return nil + } + + logger = logger.WithFields(log.Fields{ + "cn": extractedCN, + }) + + bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) + bouncer, err := a.DbClient.SelectBouncerByName(bouncerName) + + //This is likely not the proper way, but isNotFound does not seem to work + if err != nil && strings.Contains(err.Error(), "bouncer not found") { + //Because we have a valid cert, automatically create the bouncer in the database if it does not exist + //Set a random API key, but it will never be used + apiKey, err := GenerateAPIKey(dummyAPIKeySize) + if err != nil { + logger.Errorf("error generating mock api key: %s", err) + return nil + } + logger.Infof("Creating bouncer %s", bouncerName) + bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) + if err != nil { + logger.Errorf("while creating bouncer db entry: %s", err) + return nil + } + } else if err != nil { + //error while selecting bouncer + logger.Errorf("while selecting bouncers: %s", err) + return nil + } else if bouncer.AuthType != types.TlsAuthType { + //bouncer was found in DB + logger.Errorf("bouncer isn't allowed to auth by TLS") + return nil + } + return bouncer +} + +func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer { + val, ok := c.Request.Header[APIKeyHeader] + if !ok { + logger.Errorf("API key not found") + return nil + } + hashStr := HashSHA512(val[0]) + + bouncer, err := a.DbClient.SelectBouncer(hashStr) + if err != nil { + logger.Errorf("while fetching bouncer info: %s", err) + return nil + } + + if bouncer.AuthType != types.ApiKeyAuthType { + logger.Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType) + return nil + } + + return bouncer +} + func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { return func(c *gin.Context) { var bouncer *ent.Bouncer - var err error + + logger := log.WithFields(log.Fields{ + "ip": c.ClientIP(), + }) if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { - if a.TlsAuth == nil { - log.WithField("ip", c.ClientIP()).Error("TLS Auth is not configured but client presented a certificate") - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - validCert, extractedCN, err := a.TlsAuth.ValidateCert(c) - if !validCert { - log.WithField("ip", c.ClientIP()).Errorf("invalid client certificate: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - if err != nil { - log.WithField("ip", c.ClientIP()).Error(err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) - bouncer, err = a.DbClient.SelectBouncerByName(bouncerName) - //This is likely not the proper way, but isNotFound does not seem to work - if err != nil && strings.Contains(err.Error(), "bouncer not found") { - //Because we have a valid cert, automatically create the bouncer in the database if it does not exist - //Set a random API key, but it will never be used - apiKey, err := GenerateAPIKey(dummyAPIKeySize) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("error generating mock api key: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Infof("Creating bouncer %s", bouncerName) - bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("creating bouncer db entry : %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - } else if err != nil { - //error while selecting bouncer - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("while selecting bouncers: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } else if bouncer.AuthType != types.TlsAuthType { - //bouncer was found in DB - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "cn": extractedCN, - }).Errorf("bouncer isn't allowed to auth by TLS") - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } + bouncer = a.authTLS(c, logger) } else { - //API Key Authentication - val, ok := c.Request.Header[APIKeyHeader] - if !ok { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - }).Errorf("API key not found") - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - hashStr := HashSHA512(val[0]) - bouncer, err = a.DbClient.SelectBouncer(hashStr) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - }).Errorf("while fetching bouncer info: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } - if bouncer.AuthType != types.ApiKeyAuthType { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - }).Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } + bouncer = a.authPlain(c, logger) } if bouncer == nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - }).Errorf("bouncer not found") c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() return } - //maybe we want to store the whole bouncer object in the context instead, this would avoid another db query - //in StreamDecision + logger = logger.WithFields(log.Fields{ + "name": bouncer.Name, + }) + + // maybe we want to store the whole bouncer object in the context instead, this would avoid another db query + // in StreamDecision c.Set("BOUNCER_NAME", bouncer.Name) c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey) if bouncer.IPAddress == "" { - err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "name": bouncer.Name, - }).Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) + if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { + logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() return @@ -189,12 +167,8 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { if bouncer.IPAddress != c.ClientIP() && bouncer.IPAddress != "" { log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress) - err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID) - if err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "name": bouncer.Name, - }).Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) + if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil { + logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() return @@ -202,21 +176,14 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { } useragent := strings.Split(c.Request.UserAgent(), "/") - if len(useragent) != 2 { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "name": bouncer.Name, - }).Warningf("bad user agent '%s'", c.Request.UserAgent()) + logger.Warningf("bad user agent '%s'", c.Request.UserAgent()) useragent = []string{c.Request.UserAgent(), "N/A"} } if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { - log.WithFields(log.Fields{ - "ip": c.ClientIP(), - "name": bouncer.Name, - }).Errorf("failed to update bouncer version and type: %s", err) + logger.Errorf("failed to update bouncer version and type: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.Abort() return diff --git a/pkg/apiserver/middlewares/v1/middlewares.go b/pkg/apiserver/middlewares/v1/middlewares.go index 26879bd8e..ef2d93b92 100644 --- a/pkg/apiserver/middlewares/v1/middlewares.go +++ b/pkg/apiserver/middlewares/v1/middlewares.go @@ -14,7 +14,7 @@ func NewMiddlewares(dbClient *database.Client) (*Middlewares, error) { ret.JWT, err = NewJWT(dbClient) if err != nil { - return &Middlewares{}, err + return nil, err } ret.APIKey = NewAPIKey(dbClient) diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index 804337ecc..496b9b6cc 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -13,7 +13,7 @@ import ( func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX) if err != nil { - return &ent.Bouncer{}, errors.Wrapf(QueryFail, "select bouncer: %s", err) + return nil, err } return result, nil @@ -22,7 +22,7 @@ func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX) if err != nil { - return &ent.Bouncer{}, errors.Wrapf(QueryFail, "select bouncer: %s", err) + return nil, err } return result, nil @@ -31,7 +31,7 @@ func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { result, err := c.Ent.Bouncer.Query().All(c.CTX) if err != nil { - return []*ent.Bouncer{}, errors.Wrapf(QueryFail, "listing bouncer: %s", err) + return nil, errors.Wrapf(QueryFail, "listing bouncers: %s", err) } return result, nil } diff --git a/pkg/database/machines.go b/pkg/database/machines.go index 98992d478..b9834e57e 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -19,8 +19,8 @@ const CapiListsMachineID = types.ListOrigin func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) if err != nil { - c.Log.Warningf("CreateMachine : %s", err) - return nil, errors.Wrap(HashError, "") + c.Log.Warningf("CreateMachine: %s", err) + return nil, HashError } machineExist, err := c.Ent.Machine. @@ -78,7 +78,7 @@ func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { func (c *Client) ListMachines() ([]*ent.Machine, error) { machines, err := c.Ent.Machine.Query().All(c.CTX) if err != nil { - return []*ent.Machine{}, errors.Wrapf(QueryFail, "listing machines: %s", err) + return nil, errors.Wrapf(QueryFail, "listing machines: %s", err) } return machines, nil } @@ -101,7 +101,7 @@ func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) { machines, err = c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX) if err != nil { c.Log.Warningf("QueryPendingMachine : %s", err) - return []*ent.Machine{}, errors.Wrapf(QueryFail, "querying pending machines: %s", err) + return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err) } return machines, nil } @@ -190,12 +190,13 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) { return true, nil } if len(exist) > 1 { - return false, fmt.Errorf("More than one item with the same machineID in database") + return false, fmt.Errorf("more than one item with the same machineID in database") } return false, nil } + func (c *Client) QueryLastValidatedHeartbeatLT(t time.Time) ([]*ent.Machine, error) { return c.Ent.Machine.Query().Where(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)).All(c.CTX) }