diff --git a/pkg/csconfig/database.go b/pkg/csconfig/database.go index 354f18c11..91df34f20 100644 --- a/pkg/csconfig/database.go +++ b/pkg/csconfig/database.go @@ -12,19 +12,22 @@ import ( var DEFAULT_MAX_OPEN_CONNS = 100 +const defaultDecisionBulkSize = 50 + type DatabaseCfg struct { - User string `yaml:"user"` - Password string `yaml:"password"` - DbName string `yaml:"db_name"` - Sslmode string `yaml:"sslmode"` - Host string `yaml:"host"` - Port int `yaml:"port"` - DbPath string `yaml:"db_path"` - Type string `yaml:"type"` - Flush *FlushDBCfg `yaml:"flush"` - LogLevel *log.Level `yaml:"log_level"` - MaxOpenConns *int `yaml:"max_open_conns,omitempty"` - UseWal *bool `yaml:"use_wal,omitempty"` + User string `yaml:"user"` + Password string `yaml:"password"` + DbName string `yaml:"db_name"` + Sslmode string `yaml:"sslmode"` + Host string `yaml:"host"` + Port int `yaml:"port"` + DbPath string `yaml:"db_path"` + Type string `yaml:"type"` + Flush *FlushDBCfg `yaml:"flush"` + LogLevel *log.Level `yaml:"log_level"` + MaxOpenConns *int `yaml:"max_open_conns,omitempty"` + UseWal *bool `yaml:"use_wal,omitempty"` + DecisionBulkSize int `yaml:"decision_bulk_size,omitempty"` } type AuthGCCfg struct { @@ -60,11 +63,15 @@ func (c *Config) LoadDBConfig() error { c.DbConfig.MaxOpenConns = ptr.Of(DEFAULT_MAX_OPEN_CONNS) } + if c.DbConfig.DecisionBulkSize == 0 { + log.Tracef("No decision_bulk_size value provided, using default value of %d", defaultDecisionBulkSize) + c.DbConfig.DecisionBulkSize = defaultDecisionBulkSize + } + if c.DbConfig.Type == "sqlite" { if c.DbConfig.UseWal == nil { log.Warning("You are using sqlite without WAL, this can have a performance impact. If you do not store the database in a network share, set db_config.use_wal to true. Set explicitly to false to disable this warning.") } - } return nil diff --git a/pkg/csconfig/database_test.go b/pkg/csconfig/database_test.go index 017014b04..29aee2640 100644 --- a/pkg/csconfig/database_test.go +++ b/pkg/csconfig/database_test.go @@ -34,6 +34,7 @@ func TestLoadDBConfig(t *testing.T) { Type: "sqlite", DbPath: "./tests/test.db", MaxOpenConns: ptr.Of(10), + DecisionBulkSize: defaultDecisionBulkSize, }, }, { diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index d7b7db08f..fcc2cdfdc 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -30,7 +30,6 @@ const ( paginationSize = 100 // used to queryAlert to avoid 'too many SQL variable' defaultLimit = 100 // default limit of element to returns when query alerts bulkSize = 50 // bulk size when create alerts - decisionBulkSize = 50 ) func formatAlertCN(source models.Source) string { @@ -192,7 +191,7 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) log.Debugf("Adding %d missing decisions to alert %s", len(missingDecisions), foundAlert.UUID) decisions := make([]*ent.Decision, 0) - decisionBulk := make([]*ent.DecisionCreate, 0, decisionBulkSize) + decisionBulk := make([]*ent.DecisionCreate, 0, c.decisionBulkSize) for i, decisionItem := range missingDecisions { var start_ip, start_sfx, end_ip, end_sfx int64 @@ -234,17 +233,17 @@ func (c *Client) CreateOrUpdateAlert(machineID string, alertItem *models.Alert) SetUUID(decisionItem.UUID) decisionBulk = append(decisionBulk, decisionCreate) - if len(decisionBulk) == decisionBulkSize { + if len(decisionBulk) == c.decisionBulkSize { decisionsCreateRet, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX) if err != nil { return "", errors.Wrapf(BulkError, "creating alert decisions: %s", err) } decisions = append(decisions, decisionsCreateRet...) - if len(missingDecisions)-i <= decisionBulkSize { + if len(missingDecisions)-i <= c.decisionBulkSize { decisionBulk = make([]*ent.DecisionCreate, 0, (len(missingDecisions) - i)) } else { - decisionBulk = make([]*ent.DecisionCreate, 0, decisionBulkSize) + decisionBulk = make([]*ent.DecisionCreate, 0, c.decisionBulkSize) } } } @@ -353,8 +352,8 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in if err != nil { return 0, 0, 0, errors.Wrapf(BulkError, "error creating transaction : %s", err) } - decisionBulk := make([]*ent.DecisionCreate, 0, decisionBulkSize) - valueList := make([]string, 0, decisionBulkSize) + decisionBulk := make([]*ent.DecisionCreate, 0, c.decisionBulkSize) + valueList := make([]string, 0, c.decisionBulkSize) DecOrigin := CapiMachineID if *alertItem.Decisions[0].Origin == CapiMachineID || *alertItem.Decisions[0].Origin == CapiListsMachineID { DecOrigin = *alertItem.Decisions[0].Origin @@ -418,7 +417,7 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in } valueList = append(valueList, *decisionItem.Value) - if len(decisionBulk) == decisionBulkSize { + if len(decisionBulk) == c.decisionBulkSize { insertedDecisions, err := txClient.Decision.CreateBulk(decisionBulk...).Save(c.CTX) if err != nil { @@ -446,12 +445,12 @@ func (c *Client) UpdateCommunityBlocklist(alertItem *models.Alert) (int, int, in } deleted += deletedDecisions - if len(alertItem.Decisions)-i <= decisionBulkSize { + if len(alertItem.Decisions)-i <= c.decisionBulkSize { decisionBulk = make([]*ent.DecisionCreate, 0, (len(alertItem.Decisions) - i)) valueList = make([]string, 0, (len(alertItem.Decisions) - i)) } else { - decisionBulk = make([]*ent.DecisionCreate, 0, decisionBulkSize) - valueList = make([]string, 0, decisionBulkSize) + decisionBulk = make([]*ent.DecisionCreate, 0, c.decisionBulkSize) + valueList = make([]string, 0, c.decisionBulkSize) } } @@ -631,7 +630,7 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ decisions = make([]*ent.Decision, 0) if len(alertItem.Decisions) > 0 { - decisionBulk := make([]*ent.DecisionCreate, 0, decisionBulkSize) + decisionBulk := make([]*ent.DecisionCreate, 0, c.decisionBulkSize) for i, decisionItem := range alertItem.Decisions { var start_ip, start_sfx, end_ip, end_sfx int64 var sz int @@ -665,17 +664,17 @@ func (c *Client) CreateAlertBulk(machineId string, alertList []*models.Alert) ([ SetUUID(decisionItem.UUID) decisionBulk = append(decisionBulk, decisionCreate) - if len(decisionBulk) == decisionBulkSize { + if len(decisionBulk) == c.decisionBulkSize { decisionsCreateRet, err := c.Ent.Decision.CreateBulk(decisionBulk...).Save(c.CTX) if err != nil { return nil, errors.Wrapf(BulkError, "creating alert decisions: %s", err) } decisions = append(decisions, decisionsCreateRet...) - if len(alertItem.Decisions)-i <= decisionBulkSize { + if len(alertItem.Decisions)-i <= c.decisionBulkSize { decisionBulk = make([]*ent.DecisionCreate, 0, (len(alertItem.Decisions) - i)) } else { - decisionBulk = make([]*ent.DecisionCreate, 0, decisionBulkSize) + decisionBulk = make([]*ent.DecisionCreate, 0, c.decisionBulkSize) } } } diff --git a/pkg/database/database.go b/pkg/database/database.go index 9b4d7e41e..18b5dbf38 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -22,12 +22,13 @@ import ( ) type Client struct { - Ent *ent.Client - CTX context.Context - Log *log.Logger - CanFlush bool - Type string - WalMode *bool + Ent *ent.Client + CTX context.Context + Log *log.Logger + CanFlush bool + Type string + WalMode *bool + decisionBulkSize int } func getEntDriver(dbtype string, dbdialect string, dsn string, config *csconfig.DatabaseCfg) (*entsql.Driver, error) { @@ -93,7 +94,16 @@ func NewClient(config *csconfig.DatabaseCfg) (*Client, error) { if err = client.Schema.Create(context.Background()); err != nil { return nil, fmt.Errorf("failed creating schema resources: %v", err) } - return &Client{Ent: client, CTX: context.Background(), Log: clog, CanFlush: true, Type: config.Type, WalMode: config.UseWal}, nil + + return &Client{ + Ent: client, + CTX: context.Background(), + Log: clog, + CanFlush: true, + Type: config.Type, + WalMode: config.UseWal, + decisionBulkSize: config.DecisionBulkSize, + }, nil } func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Scheduler, error) {