refact BulkDeleteDecisions (#2308)

Code cleanup and de-duplication.
This commit is contained in:
mmetc 2023-11-26 22:30:03 +01:00 committed by GitHub
parent b164373997
commit 15542b78fb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 84 deletions

View file

@ -7,7 +7,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -17,6 +16,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gopkg.in/tomb.v2" "gopkg.in/tomb.v2"
"slices"
"github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/go-cs-lib/trace"
@ -383,19 +383,16 @@ func (a *apic) CAPIPullIsOld() (bool, error) {
} }
func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delete_counters map[string]map[string]int) (int, error) { func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delete_counters map[string]map[string]int) (int, error) {
var filter map[string][]string nbDeleted := 0
var nbDeleted int
for _, decision := range deletedDecisions { for _, decision := range deletedDecisions {
if strings.ToLower(*decision.Scope) == "ip" { filter := map[string][]string{
filter = make(map[string][]string, 1) "value": {*decision.Value},
filter["value"] = []string{*decision.Value} "origin": {*decision.Origin},
} else { }
filter = make(map[string][]string, 3) if strings.ToLower(*decision.Scope) != "ip" {
filter["value"] = []string{*decision.Value}
filter["type"] = []string{*decision.Type} filter["type"] = []string{*decision.Type}
filter["scopes"] = []string{*decision.Scope} filter["scopes"] = []string{*decision.Scope}
} }
filter["origin"] = []string{*decision.Origin}
dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter) dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter)
if err != nil { if err != nil {
@ -412,20 +409,17 @@ func (a *apic) HandleDeletedDecisions(deletedDecisions []*models.Decision, delet
} }
func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, delete_counters map[string]map[string]int) (int, error) { func (a *apic) HandleDeletedDecisionsV3(deletedDecisions []*modelscapi.GetDecisionsStreamResponseDeletedItem, delete_counters map[string]map[string]int) (int, error) {
var filter map[string][]string
var nbDeleted int var nbDeleted int
for _, decisions := range deletedDecisions { for _, decisions := range deletedDecisions {
scope := decisions.Scope scope := decisions.Scope
for _, decision := range decisions.Decisions { for _, decision := range decisions.Decisions {
if strings.ToLower(*scope) == "ip" { filter := map[string][]string{
filter = make(map[string][]string, 1) "value": {decision},
filter["value"] = []string{decision} "origin": {types.CAPIOrigin},
} else { }
filter = make(map[string][]string, 2) if strings.ToLower(*scope) != "ip" {
filter["value"] = []string{decision}
filter["scopes"] = []string{*scope} filter["scopes"] = []string{*scope}
} }
filter["origin"] = []string{types.CAPIOrigin}
dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter) dbCliRet, _, err := a.dbClient.SoftDeleteDecisionsWithFilter(filter)
if err != nil { if err != nil {
@ -479,30 +473,42 @@ func createAlertsForDecisions(decisions []*models.Decision) []*models.Alert {
} }
func createAlertForDecision(decision *models.Decision) *models.Alert { func createAlertForDecision(decision *models.Decision) *models.Alert {
newAlert := &models.Alert{} var (
newAlert.Source = &models.Source{} scenario string
newAlert.Source.Scope = ptr.Of("") scope string
if *decision.Origin == types.CAPIOrigin { //to make things more user friendly, we replace CAPI with community-blocklist )
newAlert.Scenario = ptr.Of(types.CAPIOrigin)
newAlert.Source.Scope = ptr.Of(types.CAPIOrigin) switch *decision.Origin {
} else if *decision.Origin == types.ListOrigin { case types.CAPIOrigin:
newAlert.Scenario = ptr.Of(*decision.Scenario) scenario = types.CAPIOrigin
newAlert.Source.Scope = ptr.Of(types.ListOrigin) scope = types.CAPIOrigin
} else { case types.ListOrigin:
scenario = *decision.Scenario
scope = types.ListOrigin
default:
// XXX: this or nil?
scenario = ""
scope = ""
log.Warningf("unknown origin %s", *decision.Origin) log.Warningf("unknown origin %s", *decision.Origin)
} }
newAlert.Message = ptr.Of("")
newAlert.Source.Value = ptr.Of("") return &models.Alert{
newAlert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) Source: &models.Source{
newAlert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339)) Scope: ptr.Of(scope),
newAlert.Capacity = ptr.Of(int32(0)) Value: ptr.Of(""),
newAlert.Simulated = ptr.Of(false) },
newAlert.EventsCount = ptr.Of(int32(0)) Scenario: ptr.Of(scenario),
newAlert.Leakspeed = ptr.Of("") Message: ptr.Of(""),
newAlert.ScenarioHash = ptr.Of("") StartAt: ptr.Of(time.Now().UTC().Format(time.RFC3339)),
newAlert.ScenarioVersion = ptr.Of("") StopAt: ptr.Of(time.Now().UTC().Format(time.RFC3339)),
newAlert.MachineID = database.CapiMachineID Capacity: ptr.Of(int32(0)),
return newAlert Simulated: ptr.Of(false),
EventsCount: ptr.Of(int32(0)),
Leakspeed: ptr.Of(""),
ScenarioHash: ptr.Of(""),
ScenarioVersion: ptr.Of(""),
MachineID: database.CapiMachineID,
}
} }
// This function takes in list of parent alerts and decisions and then pairs them up. // This function takes in list of parent alerts and decisions and then pairs them up.

View file

@ -9,6 +9,8 @@ import (
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/crowdsecurity/go-cs-lib/slicetools"
"github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate" "github.com/crowdsecurity/crowdsec/pkg/database/ent/predicate"
@ -23,7 +25,6 @@ type DecisionsByScenario struct {
} }
func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) { func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string][]string) (*ent.DecisionQuery, error) {
var err error var err error
var start_ip, start_sfx, end_ip, end_sfx int64 var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int var ip_sz int
@ -545,55 +546,39 @@ func (c *Client) SoftDeleteDecisionsWithFilter(filter map[string][]string) (stri
// BulkDeleteDecisions set the expiration of a bulk of decisions to now() or hard deletes them. // BulkDeleteDecisions set the expiration of a bulk of decisions to now() or hard deletes them.
// We are doing it this way so we can return impacted decisions for sync with CAPI/PAPI // We are doing it this way so we can return impacted decisions for sync with CAPI/PAPI
func (c *Client) BulkDeleteDecisions(DecisionsToDelete []*ent.Decision, softDelete bool) (int, error) { func (c *Client) BulkDeleteDecisions(decisionsToDelete []*ent.Decision, softDelete bool) (int, error) {
bulkSize := 256 //scientifically proven to be the best value for bulk delete const bulkSize = 256 //scientifically proven to be the best value for bulk delete
idsToDelete := make([]int, 0, bulkSize)
totalUpdates := 0
for i := 0; i < len(DecisionsToDelete); i++ {
idsToDelete = append(idsToDelete, DecisionsToDelete[i].ID)
if len(idsToDelete) == bulkSize {
if softDelete { var (
nbUpdates, err := c.Ent.Decision.Update().Where( nbUpdates int
decision.IDIn(idsToDelete...), err error
).SetUntil(time.Now().UTC()).Save(c.CTX) totalUpdates = 0
if err != nil { )
return totalUpdates, errors.Wrap(err, "soft delete decisions with provided filter")
} idsToDelete := make([]int, len(decisionsToDelete))
totalUpdates += nbUpdates for i, decision := range decisionsToDelete {
} else { idsToDelete[i] = decision.ID
nbUpdates, err := c.Ent.Decision.Delete().Where(
decision.IDIn(idsToDelete...),
).Exec(c.CTX)
if err != nil {
return totalUpdates, errors.Wrap(err, "hard delete decisions with provided filter")
}
totalUpdates += nbUpdates
}
idsToDelete = make([]int, 0, bulkSize)
}
} }
if len(idsToDelete) > 0 { for _, chunk := range slicetools.Chunks(idsToDelete, bulkSize) {
if softDelete { if softDelete {
nbUpdates, err := c.Ent.Decision.Update().Where( nbUpdates, err = c.Ent.Decision.Update().Where(
decision.IDIn(idsToDelete...), decision.IDIn(chunk...),
).SetUntil(time.Now().UTC()).Save(c.CTX) ).SetUntil(time.Now().UTC()).Save(c.CTX)
if err != nil { if err != nil {
return totalUpdates, errors.Wrap(err, "soft delete decisions with provided filter") return totalUpdates, fmt.Errorf("soft delete decisions with provided filter: %w", err)
} }
totalUpdates += nbUpdates
} else { } else {
nbUpdates, err := c.Ent.Decision.Delete().Where( nbUpdates, err = c.Ent.Decision.Delete().Where(
decision.IDIn(idsToDelete...), decision.IDIn(chunk...),
).Exec(c.CTX) ).Exec(c.CTX)
if err != nil { if err != nil {
return totalUpdates, errors.Wrap(err, "hard delete decisions with provided filter") return totalUpdates, fmt.Errorf("hard delete decisions with provided filter: %w", err)
} }
totalUpdates += nbUpdates
} }
totalUpdates += nbUpdates
} }
return totalUpdates, nil return totalUpdates, nil
} }
@ -601,6 +586,7 @@ func (c *Client) BulkDeleteDecisions(DecisionsToDelete []*ent.Decision, softDele
func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, []*ent.Decision, error) { func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, []*ent.Decision, error) {
toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX) toUpdate, err := c.Ent.Decision.Query().Where(decision.IDEQ(decisionID)).All(c.CTX)
// XXX: do we want 500 or 404 here?
if err != nil || len(toUpdate) == 0 { if err != nil || len(toUpdate) == 0 {
c.Log.Warningf("SoftDeleteDecisionByID : %v (nb soft deleted: %d)", err, len(toUpdate)) c.Log.Warningf("SoftDeleteDecisionByID : %v (nb soft deleted: %d)", err, len(toUpdate))
return 0, nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID) return 0, nil, errors.Wrapf(DeleteFail, "decision with id '%d' doesn't exist", decisionID)
@ -609,6 +595,7 @@ func (c *Client) SoftDeleteDecisionByID(decisionID int) (int, []*ent.Decision, e
if len(toUpdate) == 0 { if len(toUpdate) == 0 {
return 0, nil, ItemNotFound return 0, nil, ItemNotFound
} }
count, err := c.BulkDeleteDecisions(toUpdate, true) count, err := c.BulkDeleteDecisions(toUpdate, true)
return count, toUpdate, err return count, toUpdate, err
} }
@ -639,10 +626,7 @@ func (c *Client) CountDecisionsByValue(decisionValue string) (int, error) {
} }
func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) { func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Time) (int, error) {
var err error ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(decisionValue)
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz, count int
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(decisionValue)
if err != nil { if err != nil {
return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err) return 0, errors.Wrapf(InvalidIPOrRange, "unable to convert '%s' to int: %s", decisionValue, err)
@ -652,11 +636,13 @@ func (c *Client) CountDecisionsSinceByValue(decisionValue string, since time.Tim
decisions := c.Ent.Decision.Query().Where( decisions := c.Ent.Decision.Query().Where(
decision.CreatedAtGT(since), decision.CreatedAtGT(since),
) )
decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx) decisions, err = applyStartIpEndIpFilter(decisions, contains, ip_sz, start_ip, start_sfx, end_ip, end_sfx)
if err != nil { if err != nil {
return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter") return 0, errors.Wrapf(err, "fail to apply StartIpEndIpFilter")
} }
count, err = decisions.Count(c.CTX)
count, err := decisions.Count(c.CTX)
if err != nil { if err != nil {
return 0, errors.Wrapf(err, "fail to count decisions") return 0, errors.Wrapf(err, "fail to count decisions")
} }
@ -681,7 +667,10 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz
decision.IPSizeEQ(int64(ip_sz)), decision.IPSizeEQ(int64(ip_sz)),
)) ))
} }
} else if ip_sz == 16 { return decisions, nil
}
if ip_sz == 16 {
/*decision contains {start_ip,end_ip}*/ /*decision contains {start_ip,end_ip}*/
if contains { if contains {
decisions = decisions.Where(decision.And( decisions = decisions.Where(decision.And(
@ -733,9 +722,13 @@ func applyStartIpEndIpFilter(decisions *ent.DecisionQuery, contains bool, ip_sz
), ),
)) ))
} }
} else if ip_sz != 0 { return decisions, nil
}
if ip_sz != 0 {
return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", ip_sz) return nil, errors.Wrapf(InvalidFilter, "unknown ip size %d", ip_sz)
} }
return decisions, nil return decisions, nil
} }