diff --git a/pkg/csconfig/database.go b/pkg/csconfig/database.go index 65aa23e4a..af646ea15 100644 --- a/pkg/csconfig/database.go +++ b/pkg/csconfig/database.go @@ -4,6 +4,7 @@ import ( "fmt" "time" + "entgo.io/ent/dialect" "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" ) @@ -67,3 +68,48 @@ func (c *Config) LoadDBConfig() error { return nil } + +func (d *DatabaseCfg) ConnectionString() string { + connString := "" + switch d.Type { + case "sqlite": + var sqliteConnectionStringParameters string + if d.UseWal != nil && *d.UseWal { + sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1&_journal_mode=WAL" + } else { + sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1" + } + connString = fmt.Sprintf("file:%s?%s", d.DbPath, sqliteConnectionStringParameters) + case "mysql": + if d.isSocketConfig() { + connString = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=True", d.User, d.Password, d.DbPath, d.DbName) + } else { + connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", d.User, d.Password, d.Host, d.Port, d.DbName) + } + case "postgres", "postgresql", "pgx": + if d.isSocketConfig() { + connString = fmt.Sprintf("host=%s user=%s dbname=%s password=%s", d.DbPath, d.User, d.DbName, d.Password) + } else { + connString = fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", d.Host, d.Port, d.User, d.DbName, d.Password, d.Sslmode) + } + } + return connString +} + +func (d *DatabaseCfg) ConnectionDialect() (string, string, error) { + switch d.Type { + case "sqlite": + return "sqlite3", dialect.SQLite, nil + case "mysql": + return "mysql", dialect.MySQL, nil + case "postgres", "postgresql": + return "postgres", dialect.Postgres, nil + case "pgx": + return "pgx", dialect.Postgres, nil + } + return "", "", fmt.Errorf("unknown database type '%s'", d.Type) +} + +func (d *DatabaseCfg) isSocketConfig() bool { + return d.Host == "" && d.Port == 0 && d.DbPath != "" +} diff --git a/pkg/database/database.go b/pkg/database/database.go index 5dc247616..e6427ab9b 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -7,7 +7,6 @@ import ( "os" "time" - "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database/ent" @@ -61,8 +60,11 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { entLogger := clog.WithField("context", "ent") entOpt := ent.Log(entLogger.Debug) - switch config.Type { - case "sqlite": + typ, dia, err := config.ConnectionDialect() + if err != nil { + return &Client{}, err //unsupported database caught here + } + if config.Type == "sqlite" { /*if it's the first startup, we want to touch and chmod file*/ if _, err := os.Stat(config.DbPath); os.IsNotExist(err) { f, err := os.OpenFile(config.DbPath, os.O_CREATE|os.O_RDWR, 0600) @@ -77,45 +79,12 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { if err := setFilePerm(config.DbPath, 0640); err != nil { return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err) } - var sqliteConnectionStringParameters string - if config.UseWal != nil && *config.UseWal { - sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1&_journal_mode=WAL" - } else { - sqliteConnectionStringParameters = "_busy_timeout=100000&_fk=1" - } - drv, err := getEntDriver("sqlite3", dialect.SQLite, fmt.Sprintf("file:%s?%s", config.DbPath, sqliteConnectionStringParameters), config) - if err != nil { - return &Client{}, errors.Wrapf(err, "failed opening connection to sqlite: %v", config.DbPath) - } - client = ent.NewClient(ent.Driver(drv), entOpt) - case "mysql": - connString := "" - if config.Host == "" && config.Port == 0 && config.DbPath != "" { - connString = fmt.Sprintf("%s:%s@unix(%s)/%s?parseTime=True", config.User, config.Password, config.DbPath, config.DbName) - } else { - connString = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName) - } - drv, err := getEntDriver("mysql", dialect.MySQL, connString, config) - if err != nil { - return &Client{}, fmt.Errorf("failed opening connection to mysql: %v", err) - } - client = ent.NewClient(ent.Driver(drv), entOpt) - case "postgres", "postgresql": - drv, err := getEntDriver("postgres", dialect.Postgres, fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), config) - if err != nil { - return &Client{}, fmt.Errorf("failed opening connection to postgresql: %v", err) - } - client = ent.NewClient(ent.Driver(drv), entOpt) - case "pgx": - drv, err := getEntDriver("pgx", dialect.Postgres, fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=%s", config.User, config.Password, config.Host, config.Port, config.DbName, config.Sslmode), config) - if err != nil { - return &Client{}, fmt.Errorf("failed opening connection to pgx: %v", err) - } - client = ent.NewClient(ent.Driver(drv), entOpt) - default: - return &Client{}, fmt.Errorf("unknown database type '%s'", config.Type) } - + drv, err := getEntDriver(typ, dia, config.ConnectionString(), config) + if err != nil { + return &Client{}, fmt.Errorf("failed opening connection to %s: %v", config.Type, err) + } + client = ent.NewClient(ent.Driver(drv), entOpt) if config.LogLevel != nil && *config.LogLevel >= log.DebugLevel { clog.Debugf("Enabling request debug") client = client.Debug()