Refact bouncer auth (#2456)

Co-authored-by: blotus <sebastien@crowdsec.net>
This commit is contained in:
mmetc 2023-12-04 23:06:01 +01:00 committed by GitHub
parent a5ab73d458
commit 23968e472d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 112 additions and 145 deletions

View file

@ -30,9 +30,7 @@ import (
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
) )
var ( const passwordLength = 64
passwordLength = 64
)
func generatePassword(length int) string { func generatePassword(length int) string {
upper := "ABCDEFGHIJKLMNOPQRSTUVWXY" upper := "ABCDEFGHIJKLMNOPQRSTUVWXY"

View file

@ -106,13 +106,13 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
var flushScheduler *gocron.Scheduler var flushScheduler *gocron.Scheduler
dbClient, err := database.NewClient(config.DbConfig) dbClient, err := database.NewClient(config.DbConfig)
if err != nil { if err != nil {
return &APIServer{}, fmt.Errorf("unable to init database client: %w", err) return nil, fmt.Errorf("unable to init database client: %w", err)
} }
if config.DbConfig.Flush != nil { if config.DbConfig.Flush != nil {
flushScheduler, err = dbClient.StartFlushScheduler(config.DbConfig.Flush) flushScheduler, err = dbClient.StartFlushScheduler(config.DbConfig.Flush)
if err != nil { if err != nil {
return &APIServer{}, err return nil, err
} }
} }
@ -129,7 +129,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
if config.TrustedProxies != nil && config.UseForwardedForHeaders { if config.TrustedProxies != nil && config.UseForwardedForHeaders {
if err := router.SetTrustedProxies(*config.TrustedProxies); err != nil { if err := router.SetTrustedProxies(*config.TrustedProxies); err != nil {
return &APIServer{}, fmt.Errorf("while setting trusted_proxies: %w", err) return nil, fmt.Errorf("while setting trusted_proxies: %w", err)
} }
router.ForwardedByClientIP = true router.ForwardedByClientIP = true
} else { } else {
@ -215,7 +215,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
log.Printf("Loading CAPI manager") log.Printf("Loading CAPI manager")
apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists) apiClient, err = NewAPIC(config.OnlineClient, dbClient, config.ConsoleConfig, config.CapiWhitelists)
if err != nil { if err != nil {
return &APIServer{}, err return nil, err
} }
log.Infof("CAPI manager configured successfully") log.Infof("CAPI manager configured successfully")
isMachineEnrolled = isEnrolled(apiClient.apiClient) isMachineEnrolled = isEnrolled(apiClient.apiClient)
@ -225,7 +225,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
log.Infof("Machine is enrolled in the console, Loading PAPI Client") log.Infof("Machine is enrolled in the console, Loading PAPI Client")
papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel) papiClient, err = NewPAPI(apiClient, dbClient, config.ConsoleConfig, *config.PapiLogLevel)
if err != nil { if err != nil {
return &APIServer{}, err return nil, err
} }
controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel controller.DecisionDeleteChan = papiClient.Channels.DeleteDecisionChannel
} else { } else {
@ -241,7 +241,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
if trustedIPs, err := config.GetTrustedIPs(); err == nil { if trustedIPs, err := config.GetTrustedIPs(); err == nil {
controller.TrustedIPs = trustedIPs controller.TrustedIPs = trustedIPs
} else { } else {
return &APIServer{}, err return nil, err
} }
return &APIServer{ return &APIServer{

View file

@ -11,6 +11,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/version" "github.com/crowdsecurity/go-cs-lib/version"
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
@ -295,8 +296,8 @@ func TestWithWrongDBConfig(t *testing.T) {
config.API.Server.DbConfig.Type = "test" config.API.Server.DbConfig.Type = "test"
apiServer, err := NewServer(config.API.Server) apiServer, err := NewServer(config.API.Server)
assert.Equal(t, apiServer, &APIServer{}) cstest.RequireErrorContains(t, err, "unable to init database client: unknown database type 'test'")
assert.Equal(t, "unable to init database client: unknown database type 'test'", err.Error()) assert.Nil(t, apiServer)
} }
func TestWithWrongFlushConfig(t *testing.T) { func TestWithWrongFlushConfig(t *testing.T) {
@ -305,8 +306,8 @@ func TestWithWrongFlushConfig(t *testing.T) {
config.API.Server.DbConfig.Flush.MaxItems = &maxItems config.API.Server.DbConfig.Flush.MaxItems = &maxItems
apiServer, err := NewServer(config.API.Server) apiServer, err := NewServer(config.API.Server)
assert.Equal(t, apiServer, &APIServer{}) cstest.RequireErrorContains(t, err, "max_items can't be zero or negative number")
assert.Equal(t, "max_items can't be zero or negative number", err.Error()) assert.Nil(t, apiServer)
} }
func TestUnknownPath(t *testing.T) { func TestUnknownPath(t *testing.T) {

View file

@ -55,132 +55,110 @@ func HashSHA512(str string) string {
return hashStr return hashStr
} }
func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {
if a.TlsAuth == nil {
logger.Error("TLS Auth is not configured but client presented a certificate")
return nil
}
validCert, extractedCN, err := a.TlsAuth.ValidateCert(c)
if !validCert {
logger.Errorf("invalid client certificate: %s", err)
return nil
}
if err != nil {
logger.Error(err)
return nil
}
logger = logger.WithFields(log.Fields{
"cn": extractedCN,
})
bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
bouncer, err := a.DbClient.SelectBouncerByName(bouncerName)
//This is likely not the proper way, but isNotFound does not seem to work
if err != nil && strings.Contains(err.Error(), "bouncer not found") {
//Because we have a valid cert, automatically create the bouncer in the database if it does not exist
//Set a random API key, but it will never be used
apiKey, err := GenerateAPIKey(dummyAPIKeySize)
if err != nil {
logger.Errorf("error generating mock api key: %s", err)
return nil
}
logger.Infof("Creating bouncer %s", bouncerName)
bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType)
if err != nil {
logger.Errorf("while creating bouncer db entry: %s", err)
return nil
}
} else if err != nil {
//error while selecting bouncer
logger.Errorf("while selecting bouncers: %s", err)
return nil
} else if bouncer.AuthType != types.TlsAuthType {
//bouncer was found in DB
logger.Errorf("bouncer isn't allowed to auth by TLS")
return nil
}
return bouncer
}
func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer {
val, ok := c.Request.Header[APIKeyHeader]
if !ok {
logger.Errorf("API key not found")
return nil
}
hashStr := HashSHA512(val[0])
bouncer, err := a.DbClient.SelectBouncer(hashStr)
if err != nil {
logger.Errorf("while fetching bouncer info: %s", err)
return nil
}
if bouncer.AuthType != types.ApiKeyAuthType {
logger.Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType)
return nil
}
return bouncer
}
func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
var bouncer *ent.Bouncer var bouncer *ent.Bouncer
var err error
logger := log.WithFields(log.Fields{
"ip": c.ClientIP(),
})
if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
if a.TlsAuth == nil { bouncer = a.authTLS(c, logger)
log.WithField("ip", c.ClientIP()).Error("TLS Auth is not configured but client presented a certificate")
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
validCert, extractedCN, err := a.TlsAuth.ValidateCert(c)
if !validCert {
log.WithField("ip", c.ClientIP()).Errorf("invalid client certificate: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
if err != nil {
log.WithField("ip", c.ClientIP()).Error(err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
bouncer, err = a.DbClient.SelectBouncerByName(bouncerName)
//This is likely not the proper way, but isNotFound does not seem to work
if err != nil && strings.Contains(err.Error(), "bouncer not found") {
//Because we have a valid cert, automatically create the bouncer in the database if it does not exist
//Set a random API key, but it will never be used
apiKey, err := GenerateAPIKey(dummyAPIKeySize)
if err != nil {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Errorf("error generating mock api key: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Infof("Creating bouncer %s", bouncerName)
bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType)
if err != nil {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Errorf("creating bouncer db entry : %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
} else if err != nil {
//error while selecting bouncer
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Errorf("while selecting bouncers: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
} else if bouncer.AuthType != types.TlsAuthType {
//bouncer was found in DB
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Errorf("bouncer isn't allowed to auth by TLS")
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
} else { } else {
//API Key Authentication bouncer = a.authPlain(c, logger)
val, ok := c.Request.Header[APIKeyHeader]
if !ok {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
}).Errorf("API key not found")
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
hashStr := HashSHA512(val[0])
bouncer, err = a.DbClient.SelectBouncer(hashStr)
if err != nil {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
}).Errorf("while fetching bouncer info: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
if bouncer.AuthType != types.ApiKeyAuthType {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
}).Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
} }
if bouncer == nil { if bouncer == nil {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
}).Errorf("bouncer not found")
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
return return
} }
//maybe we want to store the whole bouncer object in the context instead, this would avoid another db query logger = logger.WithFields(log.Fields{
//in StreamDecision "name": bouncer.Name,
})
// maybe we want to store the whole bouncer object in the context instead, this would avoid another db query
// in StreamDecision
c.Set("BOUNCER_NAME", bouncer.Name) c.Set("BOUNCER_NAME", bouncer.Name)
c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey) c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey)
if bouncer.IPAddress == "" { if bouncer.IPAddress == "" {
err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID) if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil {
if err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"name": bouncer.Name,
}).Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
return return
@ -189,12 +167,8 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
if bouncer.IPAddress != c.ClientIP() && bouncer.IPAddress != "" { if bouncer.IPAddress != c.ClientIP() && bouncer.IPAddress != "" {
log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress) log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress)
err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID) if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil {
if err != nil { logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"name": bouncer.Name,
}).Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
return return
@ -202,21 +176,14 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
} }
useragent := strings.Split(c.Request.UserAgent(), "/") useragent := strings.Split(c.Request.UserAgent(), "/")
if len(useragent) != 2 { if len(useragent) != 2 {
log.WithFields(log.Fields{ logger.Warningf("bad user agent '%s'", c.Request.UserAgent())
"ip": c.ClientIP(),
"name": bouncer.Name,
}).Warningf("bad user agent '%s'", c.Request.UserAgent())
useragent = []string{c.Request.UserAgent(), "N/A"} useragent = []string{c.Request.UserAgent(), "N/A"}
} }
if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] {
if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil {
log.WithFields(log.Fields{ logger.Errorf("failed to update bouncer version and type: %s", err)
"ip": c.ClientIP(),
"name": bouncer.Name,
}).Errorf("failed to update bouncer version and type: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"})
c.Abort() c.Abort()
return return

View file

@ -14,7 +14,7 @@ func NewMiddlewares(dbClient *database.Client) (*Middlewares, error) {
ret.JWT, err = NewJWT(dbClient) ret.JWT, err = NewJWT(dbClient)
if err != nil { if err != nil {
return &Middlewares{}, err return nil, err
} }
ret.APIKey = NewAPIKey(dbClient) ret.APIKey = NewAPIKey(dbClient)

View file

@ -13,7 +13,7 @@ import (
func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX) result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX)
if err != nil { if err != nil {
return &ent.Bouncer{}, errors.Wrapf(QueryFail, "select bouncer: %s", err) return nil, err
} }
return result, nil return result, nil
@ -22,7 +22,7 @@ func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) {
func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX) result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX)
if err != nil { if err != nil {
return &ent.Bouncer{}, errors.Wrapf(QueryFail, "select bouncer: %s", err) return nil, err
} }
return result, nil return result, nil
@ -31,7 +31,7 @@ func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) {
func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { func (c *Client) ListBouncers() ([]*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().All(c.CTX) result, err := c.Ent.Bouncer.Query().All(c.CTX)
if err != nil { if err != nil {
return []*ent.Bouncer{}, errors.Wrapf(QueryFail, "listing bouncer: %s", err) return nil, errors.Wrapf(QueryFail, "listing bouncers: %s", err)
} }
return result, nil return result, nil
} }

View file

@ -19,8 +19,8 @@ const CapiListsMachineID = types.ListOrigin
func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) {
hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
if err != nil { if err != nil {
c.Log.Warningf("CreateMachine : %s", err) c.Log.Warningf("CreateMachine: %s", err)
return nil, errors.Wrap(HashError, "") return nil, HashError
} }
machineExist, err := c.Ent.Machine. machineExist, err := c.Ent.Machine.
@ -78,7 +78,7 @@ func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) {
func (c *Client) ListMachines() ([]*ent.Machine, error) { func (c *Client) ListMachines() ([]*ent.Machine, error) {
machines, err := c.Ent.Machine.Query().All(c.CTX) machines, err := c.Ent.Machine.Query().All(c.CTX)
if err != nil { if err != nil {
return []*ent.Machine{}, errors.Wrapf(QueryFail, "listing machines: %s", err) return nil, errors.Wrapf(QueryFail, "listing machines: %s", err)
} }
return machines, nil return machines, nil
} }
@ -101,7 +101,7 @@ func (c *Client) QueryPendingMachine() ([]*ent.Machine, error) {
machines, err = c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX) machines, err = c.Ent.Machine.Query().Where(machine.IsValidatedEQ(false)).All(c.CTX)
if err != nil { if err != nil {
c.Log.Warningf("QueryPendingMachine : %s", err) c.Log.Warningf("QueryPendingMachine : %s", err)
return []*ent.Machine{}, errors.Wrapf(QueryFail, "querying pending machines: %s", err) return nil, errors.Wrapf(QueryFail, "querying pending machines: %s", err)
} }
return machines, nil return machines, nil
} }
@ -190,12 +190,13 @@ func (c *Client) IsMachineRegistered(machineID string) (bool, error) {
return true, nil return true, nil
} }
if len(exist) > 1 { if len(exist) > 1 {
return false, fmt.Errorf("More than one item with the same machineID in database") return false, fmt.Errorf("more than one item with the same machineID in database")
} }
return false, nil return false, nil
} }
func (c *Client) QueryLastValidatedHeartbeatLT(t time.Time) ([]*ent.Machine, error) { func (c *Client) QueryLastValidatedHeartbeatLT(t time.Time) ([]*ent.Machine, error) {
return c.Ent.Machine.Query().Where(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)).All(c.CTX) return c.Ent.Machine.Query().Where(machine.LastHeartbeatLT(t), machine.IsValidatedEQ(true)).All(c.CTX)
} }