diff --git a/pkg/database/database.go b/pkg/database/database.go index 83027936f..b5bb29496 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -7,6 +7,7 @@ import ( "os" "time" + "entgo.io/ent/dialect" entsql "entgo.io/ent/dialect/sql" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database/ent" @@ -27,7 +28,7 @@ type Client struct { CanFlush bool } -func getEntDriver(dbtype string, dsn string, config *csconfig.DatabaseCfg) (*entsql.Driver, error) { +func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig.DatabaseCfg) (*entsql.Driver, error) { db, err := sql.Open(dbtype, dsn) if err != nil { return nil, err @@ -37,7 +38,7 @@ func getEntDriver(dbtype string, dsn string, config *csconfig.DatabaseCfg) (*ent config.MaxOpenConns = types.IntPtr(csconfig.DEFAULT_MAX_OPEN_CONNS) } db.SetMaxOpenConns(*config.MaxOpenConns) - drv := entsql.OpenDB(dbtype, db) + drv := entsql.OpenDB(dbdialect, db) return drv, nil } @@ -75,25 +76,25 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { return &Client{}, fmt.Errorf("unable to set perms on %s: %v", config.DbPath, err) } } - drv, err := getEntDriver("sqlite3", fmt.Sprintf("file:%s?_busy_timeout=100000&_fk=1", config.DbPath), config) + drv, err := getEntDriver("sqlite3", dialect.SQLite, fmt.Sprintf("file:%s?_busy_timeout=100000&_fk=1", config.DbPath), config) if err != nil { return &Client{}, errors.Wrapf(err, "failed opening connection to sqlite: %v", config.DbPath) } client = ent.NewClient(ent.Driver(drv), entOpt) case "mysql": - drv, err := getEntDriver("mysql", fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName), config) + drv, err := getEntDriver("mysql", dialect.MySQL, fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=True", config.User, config.Password, config.Host, config.Port, config.DbName), config) if err != nil { return &Client{}, fmt.Errorf("failed opening connection to mysql: %v", err) } client = ent.NewClient(ent.Driver(drv), entOpt) case "postgres", "postgresql": - drv, err := getEntDriver("postgres", fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), config) + drv, err := getEntDriver("postgres", dialect.Postgres, fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=%s", config.Host, config.Port, config.User, config.DbName, config.Password, config.Sslmode), config) if err != nil { return &Client{}, fmt.Errorf("failed opening connection to postgresql: %v", err) } client = ent.NewClient(ent.Driver(drv), entOpt) case "pgx": - drv, err := getEntDriver("pgx", fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=%s", config.User, config.Password, config.Host, config.Port, config.DbName, config.Sslmode), config) + drv, err := getEntDriver("pgx", dialect.Postgres, fmt.Sprintf("postgresql://%s:%s@%s:%d/%s?sslmode=%s", config.User, config.Password, config.Host, config.Port, config.DbName, config.Sslmode), config) if err != nil { return &Client{}, fmt.Errorf("failed opening connection to pgx: %v", err) }