diff --git a/config/config.yaml b/config/config.yaml index d5c0a11a2..38db1419e 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -22,6 +22,7 @@ db_config: log_level: info type: sqlite db_path: /var/lib/crowdsec/data/crowdsec.db + #max_open_conns: 100 #user: #password: #db_name: diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 36392275e..2778f4e0d 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -100,11 +100,13 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { useragent = []string{c.Request.UserAgent(), "N/A"} } - if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { - log.Errorf("failed to update bouncer version and type from '%s' (%s): %s", c.Request.UserAgent(), c.ClientIP(), err) - c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) - c.Abort() - return + if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { + if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { + log.Errorf("failed to update bouncer version and type from '%s' (%s): %s", c.Request.UserAgent(), c.ClientIP(), err) + c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) + c.Abort() + return + } } c.Next() diff --git a/pkg/csconfig/api_test.go b/pkg/csconfig/api_test.go index 9d8b9ea93..9f4e8794e 100644 --- a/pkg/csconfig/api_test.go +++ b/pkg/csconfig/api_test.go @@ -204,8 +204,9 @@ func TestLoadAPIServer(t *testing.T) { ListenURI: "http://crowdsec.api", TLS: nil, DbConfig: &DatabaseCfg{ - DbPath: "./tests/test.db", - Type: "sqlite", + DbPath: "./tests/test.db", + Type: "sqlite", + MaxOpenConns: types.IntPtr(DEFAULT_MAX_OPEN_CONNS), }, ConsoleConfigPath: DefaultConfigPath("console.yaml"), ConsoleConfig: &ConsoleConfig{ diff --git a/pkg/csconfig/config.go b/pkg/csconfig/config.go index ae44048b9..e0ea65843 100644 --- a/pkg/csconfig/config.go +++ b/pkg/csconfig/config.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" + "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" @@ -107,8 +108,9 @@ func NewDefaultConfig() *Config { } dbConfig := DatabaseCfg{ - Type: "sqlite", - DbPath: DefaultDataPath("crowdsec.db"), + Type: "sqlite", + DbPath: DefaultDataPath("crowdsec.db"), + MaxOpenConns: types.IntPtr(DEFAULT_MAX_OPEN_CONNS), } globalCfg := Config{ diff --git a/pkg/csconfig/database.go b/pkg/csconfig/database.go index 9449a7833..979b0da8c 100644 --- a/pkg/csconfig/database.go +++ b/pkg/csconfig/database.go @@ -3,20 +3,24 @@ package csconfig import ( "fmt" + "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" ) +var DEFAULT_MAX_OPEN_CONNS = 100 + type DatabaseCfg struct { - User string `yaml:"user"` - Password string `yaml:"password"` - DbName string `yaml:"db_name"` - Sslmode string `yaml:"sslmode"` - Host string `yaml:"host"` - Port int `yaml:"port"` - DbPath string `yaml:"db_path"` - Type string `yaml:"type"` - Flush *FlushDBCfg `yaml:"flush"` - LogLevel *log.Level `yaml:"log_level"` + User string `yaml:"user"` + Password string `yaml:"password"` + DbName string `yaml:"db_name"` + Sslmode string `yaml:"sslmode"` + Host string `yaml:"host"` + Port int `yaml:"port"` + DbPath string `yaml:"db_path"` + Type string `yaml:"type"` + Flush *FlushDBCfg `yaml:"flush"` + LogLevel *log.Level `yaml:"log_level"` + MaxOpenConns *int `yaml:"max_open_conns,omitempty"` } type FlushDBCfg struct { @@ -37,5 +41,8 @@ func (c *Config) LoadDBConfig() error { c.API.Server.DbConfig = c.DbConfig } + if c.DbConfig.MaxOpenConns == nil { + c.DbConfig.MaxOpenConns = types.IntPtr(DEFAULT_MAX_OPEN_CONNS) + } return nil } diff --git a/pkg/csconfig/database_test.go b/pkg/csconfig/database_test.go index b51b6ac10..b029f3883 100644 --- a/pkg/csconfig/database_test.go +++ b/pkg/csconfig/database_test.go @@ -5,6 +5,7 @@ import ( "strings" "testing" + "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/stretchr/testify/assert" ) @@ -19,8 +20,9 @@ func TestLoadDBConfig(t *testing.T) { name: "basic valid configuration", Input: &Config{ DbConfig: &DatabaseCfg{ - Type: "sqlite", - DbPath: "./tests/test.db", + Type: "sqlite", + DbPath: "./tests/test.db", + MaxOpenConns: types.IntPtr(10), }, Cscli: &CscliCfg{}, API: &APICfg{ @@ -28,8 +30,9 @@ func TestLoadDBConfig(t *testing.T) { }, }, expectedResult: &DatabaseCfg{ - Type: "sqlite", - DbPath: "./tests/test.db", + Type: "sqlite", + DbPath: "./tests/test.db", + MaxOpenConns: types.IntPtr(10), }, }, { diff --git a/pkg/database/database.go b/pkg/database/database.go index 2cd6831b5..83027936f 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -7,7 +7,6 @@ import ( "os" "time" - "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database/ent" @@ -28,6 +27,20 @@ type Client struct { CanFlush bool } +func getEntDriver(dbtype string, dsn string, config *csconfig.DatabaseCfg) (*entsql.Driver, error) { + db, err := sql.Open(dbtype, dsn) + if err != nil { + return nil, err + } + if config.MaxOpenConns == nil { + log.Warningf("MaxOpenConns is 0, defaulting to %d", csconfig.DEFAULT_MAX_OPEN_CONNS) + config.MaxOpenConns = types.IntPtr(csconfig.DEFAULT_MAX_OPEN_CONNS) + } + db.SetMaxOpenConns(*config.MaxOpenConns) + drv := entsql.OpenDB(dbtype, db) + return drv, nil +} + func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { var client *ent.Client var err error @@ -62,27 +75,28 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err) } } - client, err = ent.Open("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=100000&_fk=1", config.DbPath), entOpt) + drv, err := getEntDriver("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=100000&_fk=1", config.DbPath), config) if err != nil { - return &Client{}, fmt.Errorf("failed opening connection to sqlite: %v", err) + return &Client{}, errors.Wrapf(err, "failed opening connection to sqlite: %v", config.DbPath) } + client = ent.NewClient(ent.Driver(drv), entOpt) case "mysql": - client, err = ent.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName), entOpt) + drv, err := getEntDriver("mysql", fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName), config) if err != nil { return &Client{}, fmt.Errorf("failed opening connection to mysql: %v", err) } + client = ent.NewClient(ent.Driver(drv), entOpt) case "postgres", "postgresql": - client, err = ent.Open("postgres", fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), entOpt) + drv, err := getEntDriver("postgres", fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), config) if err != nil { - return &Client{}, fmt.Errorf("failed opening connection to postgres: %v", err) + return &Client{}, fmt.Errorf("failed opening connection to postgresql: %v", err) } + client = ent.NewClient(ent.Driver(drv), entOpt) case "pgx": - db, err := sql.Open("pgx", fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=%s", config.User, config.Password, config.Host, config.Port, config.DbName, config.Sslmode)) + drv, err := getEntDriver("pgx", fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=%s", config.User, config.Password, config.Host, config.Port, config.DbName, config.Sslmode), config) if err != nil { return &Client{}, fmt.Errorf("failed opening connection to pgx: %v", err) } - // Create an ent.Driver from `db`. - drv := entsql.OpenDB(dialect.Postgres, db) client = ent.NewClient(ent.Driver(drv), entOpt) default: return &Client{}, fmt.Errorf("unknown database type") diff --git a/pkg/database/ent/migrate/schema.go b/pkg/database/ent/migrate/schema.go index 3909c1d4c..f6688aa30 100644 --- a/pkg/database/ent/migrate/schema.go +++ b/pkg/database/ent/migrate/schema.go @@ -108,6 +108,13 @@ var ( OnDelete: schema.Cascade, }, }, + Indexes: []*schema.Index{ + { + Name: "decision_start_ip_end_ip", + Unique: false, + Columns: []*schema.Column{DecisionsColumns[6], DecisionsColumns[7]}, + }, + }, } // EventsColumns holds the columns for the "events" table. EventsColumns = []*schema.Column{ diff --git a/pkg/database/ent/schema/decision.go b/pkg/database/ent/schema/decision.go index ebae2377f..02da5c0cc 100644 --- a/pkg/database/ent/schema/decision.go +++ b/pkg/database/ent/schema/decision.go @@ -4,6 +4,7 @@ import ( "entgo.io/ent" "entgo.io/ent/schema/edge" "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -44,3 +45,9 @@ func (Decision) Edges() []ent.Edge { Unique(), } } + +func (Decision) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("start_ip", "end_ip"), + } +} diff --git a/pkg/types/utils.go b/pkg/types/utils.go index fb04c65bd..55c625db0 100644 --- a/pkg/types/utils.go +++ b/pkg/types/utils.go @@ -219,6 +219,10 @@ func StrPtr(s string) *string { return &s } +func IntPtr(i int) *int { + return &i +} + func Int32Ptr(i int32) *int32 { return &i }