From 15542b78fbf12b05f2d9697994d9b8f081b34e48 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Sun, 26 Nov 2023 22:30:03 +0100 Subject: [PATCH] refact BulkDeleteDecisions (#2308) Code cleanup and de-duplication. --- pkg/apiserver/apic.go | 86 +++++++++++++++++++++------------------ pkg/database/decisions.go | 81 +++++++++++++++++------------------- 2 files changed, 83 insertions(+), 84 deletions(-) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index a63c8e2aa..a199e2892 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -7,7 +7,6 @@ import ( "net" "net/http" "net/url" - "slices" "strconv" "strings" "sync" @@ -17,6 +16,7 @@ import ( "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" + "slices" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/trace" @@ -383,19 +383,16 @@ func (a *apic) CAPIPullIsOld() (bool, error) { } func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delete_counters map[string]map[string]int) (int, error) { - var filter map[string][]string - var nbDeleted int + nbDeleted := 0 for _, decision := range deletedDecisions { - if strings.ToLower(*decision.Scope) == "ip" { - filter = make(map[string][]string, 1) - filter["value"] = []string{*decision.Value} - } else { - filter = make(map[string][]string, 3) - filter["value"] = []string{*decision.Value} + filter := map[string][]string{ + "value": {*decision.Value}, + "origin": {*decision.Origin}, + } + if strings.ToLower(*decision.Scope) != "ip" { filter["type"] = []string{*decision.Type} filter["scopes"] = []string{*decision.Scope} } - filter["origin"] = []string{*decision.Origin} dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter) if err != nil { @@ -412,20 +409,17 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet } func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, delete_counters map[string]map[string]int) (int, error) { - var filter map[string][]string var nbDeleted int for _, decisions := range deletedDecisions { scope := decisions.Scope for _, decision := range decisions.Decisions { - if strings.ToLower(*scope) == "ip" { - filter = make(map[string][]string, 1) - filter["value"] = []string{decision} - } else { - filter = make(map[string][]string, 2) - filter["value"] = []string{decision} + filter := map[string][]string{ + "value": {decision}, + "origin": {types.CAPIOrigin}, + } + if strings.ToLower(*scope) != "ip" { filter["scopes"] = []string{*scope} } - filter["origin"] = []string{types.CAPIOrigin} dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter) if err != nil { @@ -479,30 +473,42 @@ func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert { } func createAlertForDecision(decision *models.Decision) *models.Alert { - newAlert := &models.Alert{} - newAlert.Source = &models.Source{} - newAlert.Source.Scope = ptr.Of("") - if *decision.Origin == types.CAPIOrigin { //to make things more user friendly, we replace CAPI with community-blocklist - newAlert.Scenario = ptr.Of(types.CAPIOrigin) - newAlert.Source.Scope = ptr.Of(types.CAPIOrigin) - } else if *decision.Origin == types.ListOrigin { - newAlert.Scenario = ptr.Of(*decision.Scenario) - newAlert.Source.Scope = ptr.Of(types.ListOrigin) - } else { + var ( + scenario string + scope string + ) + + switch *decision.Origin { + case types.CAPIOrigin: + scenario = types.CAPIOrigin + scope = types.CAPIOrigin + case types.ListOrigin: + scenario = *decision.Scenario + scope = types.ListOrigin + default: + // XXX: this or nil? + scenario = "" + scope = "" log.Warningf("unknown origin %s", *decision.Origin) } - newAlert.Message = ptr.Of("") - newAlert.Source.Value = ptr.Of("") - newAlert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) - newAlert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) - newAlert.Capacity = ptr.Of(int32(0)) - newAlert.Simulated = ptr.Of(false) - newAlert.EventsCount = ptr.Of(int32(0)) - newAlert.Leakspeed = ptr.Of("") - newAlert.ScenarioHash = ptr.Of("") - newAlert.ScenarioVersion = ptr.Of("") - newAlert.MachineID = database.CapiMachineID - return newAlert + + return &models.Alert{ + Source: &models.Source{ + Scope: ptr.Of(scope), + Value: ptr.Of(""), + }, + Scenario: ptr.Of(scenario), + Message: ptr.Of(""), + StartAt: ptr.Of(time.Now().UTC().Format(time.RFC3339)), + StopAt: ptr.Of(time.Now().UTC().Format(time.RFC3339)), + Capacity: ptr.Of(int32(0)), + Simulated: ptr.Of(false), + EventsCount: ptr.Of(int32(0)), + Leakspeed: ptr.Of(""), + ScenarioHash: ptr.Of(""), + ScenarioVersion: ptr.Of(""), + MachineID: database.CapiMachineID, + } } // This function takes in list of parent alerts and decisions and then pairs them up. diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index cf4b9c966..c4ea0bb11 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -9,6 +9,8 @@ import ( "entgo.io/ent/dialect/sql" "github.com/pkg/errors" + "github.com/crowdsecurity/go-cs-lib/slicetools" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" @@ -23,7 +25,6 @@ type DecisionsByScenario struct { } func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) { - var err error var start_ip, start_sfx, end_ip, end_sfx int64 var ip_sz int @@ -545,55 +546,39 @@ func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (stri // BulkDeleteDecisions set the expiration of a bulk of decisions to now() or hard deletes them. // We are doing it this way so we can return impacted decisions for sync with CAPI/PAPI -func (c *Client) BulkDeleteDecisions(DecisionsToDelete []*ent.Decision, softDelete bool) (int, error) { - bulkSize := 256 //scientifically proven to be the best value for bulk delete - idsToDelete := make([]int, 0, bulkSize) - totalUpdates := 0 - for i := 0; i < len(DecisionsToDelete); i++ { - idsToDelete = append(idsToDelete, DecisionsToDelete[i].ID) - if len(idsToDelete) == bulkSize { +func (c *Client) BulkDeleteDecisions(decisionsToDelete []*ent.Decision, softDelete bool) (int, error) { + const bulkSize = 256 //scientifically proven to be the best value for bulk delete - if softDelete { - nbUpdates, err := c.Ent.Decision.Update().Where( - decision.IDIn(idsToDelete...), - ).SetUntil(time.Now().UTC()).Save(c.CTX) - if err != nil { - return totalUpdates, errors.Wrap(err, "soft delete decisions with provided filter") - } - totalUpdates += nbUpdates - } else { - nbUpdates, err := c.Ent.Decision.Delete().Where( - decision.IDIn(idsToDelete...), - ).Exec(c.CTX) - if err != nil { - return totalUpdates, errors.Wrap(err, "hard delete decisions with provided filter") - } - totalUpdates += nbUpdates - } - idsToDelete = make([]int, 0, bulkSize) - } + var ( + nbUpdates int + err error + totalUpdates = 0 + ) + + idsToDelete := make([]int, len(decisionsToDelete)) + for i, decision := range decisionsToDelete { + idsToDelete[i] = decision.ID } - if len(idsToDelete) > 0 { + for _, chunk := range slicetools.Chunks(idsToDelete, bulkSize) { if softDelete { - nbUpdates, err := c.Ent.Decision.Update().Where( - decision.IDIn(idsToDelete...), + nbUpdates, err = c.Ent.Decision.Update().Where( + decision.IDIn(chunk...), ).SetUntil(time.Now().UTC()).Save(c.CTX) if err != nil { - return totalUpdates, errors.Wrap(err, "soft delete decisions with provided filter") + return totalUpdates, fmt.Errorf("soft delete decisions with provided filter: %w", err) } - totalUpdates += nbUpdates } else { - nbUpdates, err := c.Ent.Decision.Delete().Where( - decision.IDIn(idsToDelete...), + nbUpdates, err = c.Ent.Decision.Delete().Where( + decision.IDIn(chunk...), ).Exec(c.CTX) if err != nil { - return totalUpdates, errors.Wrap(err, "hard delete decisions with provided filter") + return totalUpdates, fmt.Errorf("hard delete decisions with provided filter: %w", err) } - totalUpdates += nbUpdates } - + totalUpdates += nbUpdates } + return totalUpdates, nil } @@ -601,6 +586,7 @@ func (c *Client) BulkDeleteDecisions(DecisionsToDelete []*ent.Decision, softDele func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, []*ent.Decision, error) { toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) + // XXX: do we want 500 or 404 here? if err != nil || len(toUpdate) == 0 { c.Log.Warningf("SoftDeleteDecisionByID : %v (nb soft deleted: %d)", err, len(toUpdate)) return 0, nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID) @@ -609,6 +595,7 @@ func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, []*ent.Decision, e if len(toUpdate) == 0 { return 0, nil, ItemNotFound } + count, err := c.BulkDeleteDecisions(toUpdate, true) return count, toUpdate, err } @@ -639,10 +626,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) { } func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) { - var err error - var start_ip, start_sfx, end_ip, end_sfx int64 - var ip_sz, count int - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue) + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue) if err != nil { return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) @@ -652,11 +636,13 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim decisions := c.Ent.Decision.Query().Where( decision.CreatedAtGT(since), ) + decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) if err != nil { return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") } - count, err = decisions.Count(c.CTX) + + count, err := decisions.Count(c.CTX) if err != nil { return 0, errors.Wrapf(err, "fail to count decisions") } @@ -681,7 +667,10 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz decision.IPSizeEQ(int64(ip_sz)), )) } - } else if ip_sz == 16 { + return decisions, nil + } + + if ip_sz == 16 { /*decision contains {start_ip,end_ip}*/ if contains { decisions = decisions.Where(decision.And( @@ -733,9 +722,13 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz ), )) } - } else if ip_sz != 0 { + return decisions, nil + } + + if ip_sz != 0 { return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", ip_sz) } + return decisions, nil }