diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index 4513ba3e7..4ce3b162f 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -224,7 +224,7 @@ func TestDeleteDecisionByID(t *testing.T) { decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, 200, code) - //assert.Equal(t, 0, len(decisions["deleted"])) + assert.Equal(t, 0, len(decisions["deleted"])) assert.Equal(t, 1, len(decisions["new"])) } @@ -276,12 +276,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, nil, err) assert.Equal(t, 200, code) - //assert.Equal(t, 0, len(decisions["deleted"])) + assert.Equal(t, 0, len(decisions["deleted"])) assert.Equal(t, 1, len(decisions["new"])) assert.Equal(t, int64(2), decisions["new"][0].ID) assert.Equal(t, "test", *decisions["new"][0].Origin) assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) - // We delete another decision, yet don't receive it in stream, since there's another decision on same IP w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) @@ -291,7 +290,7 @@ func TestStreamStartDecisionDedup(t *testing.T) { decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, nil, err) assert.Equal(t, 200, code) - //assert.Equal(t, 0, len(decisions["deleted"])) + assert.Equal(t, 0, len(decisions["deleted"])) assert.Equal(t, 1, len(decisions["new"])) assert.Equal(t, int64(1), decisions["new"][0].ID) assert.Equal(t, "test", *decisions["new"][0].Origin) @@ -1012,6 +1011,104 @@ func TestStreamDecision(t *testing.T) { NewChecks: []DecisionCheck{}, }, }, + "test startup with scenarios containing": { + { + TestName: "get stream", + Method: "GET", + Route: "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf", + CheckCodeOnly: false, + Code: 200, + LenNew: 2, + LenDeleted: 0, + AuthType: APIKEY, + DelChecks: []DecisionCheck{}, + NewChecks: []DecisionCheck{ + { + ID: int64(2), + Origin: "another_origin", + Scenario: "crowdsecurity/ssh_bf", + Value: "127.0.0.1", + Duration: "2h59", + Type: "ban", + }, + { + ID: int64(5), + Origin: "test", + Scenario: "crowdsecurity/ssh_bf", + Value: "127.0.0.2", + Duration: "2h59", + Type: "ban", + }, + }, + }, + { + TestName: "delete decisions 3 (127.0.0.1)", + Method: "DELETE", + Route: "/v1/decisions/3", + CheckCodeOnly: true, + Code: 200, + LenNew: 0, + LenDeleted: 0, + AuthType: PASSWORD, + DelChecks: []DecisionCheck{}, + NewChecks: []DecisionCheck{}, + }, + { + TestName: "check that 127.0.0.1 is not in deleted IP", + Method: "GET", + Route: "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf", + CheckCodeOnly: false, + Code: 200, + LenNew: 2, + LenDeleted: 0, + AuthType: APIKEY, + DelChecks: []DecisionCheck{}, + NewChecks: []DecisionCheck{}, + }, + { + TestName: "delete decisions 2 (127.0.0.1)", + Method: "DELETE", + Route: "/v1/decisions/2", + CheckCodeOnly: true, + Code: 200, + LenNew: 0, + LenDeleted: 0, + AuthType: PASSWORD, + DelChecks: []DecisionCheck{}, + NewChecks: []DecisionCheck{}, + }, + { + TestName: "check that 127.0.0.1 is deleted (decision for ssh_bf was with ID 2)", + Method: "GET", + Route: "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf", + CheckCodeOnly: false, + Code: 200, + LenNew: 1, + LenDeleted: 1, + AuthType: APIKEY, + DelChecks: []DecisionCheck{ + { + ID: int64(2), + Origin: "another_origin", + Scenario: "crowdsecurity/ssh_bf", + Value: "127.0.0.1", + Duration: "-", + + Type: "ban", + }, + }, + NewChecks: []DecisionCheck{ + { + ID: int64(5), + Origin: "test", + Scenario: "crowdsecurity/ssh_bf", + Value: "127.0.0.2", + Duration: "2h59", + Type: "ban", + }, + }, + }, + }, "test with scenarios containing": { { TestName: "get stream", @@ -1344,7 +1441,7 @@ func runTest(lapi LAPI, test DecisionTest, t *testing.T) { } decisions, _, err := readDecisionsStreamResp(w) assert.Equal(t, nil, err) - //assert.Equal(t, test.LenDeleted, len(decisions["deleted"]), fmt.Sprintf("'%s': len(deleted)", test.TestName)) + assert.Equal(t, test.LenDeleted, len(decisions["deleted"]), fmt.Sprintf("'%s': len(deleted)", test.TestName)) assert.Equal(t, test.LenNew, len(decisions["new"]), fmt.Sprintf("'%s': len(new)", test.TestName)) for i, check := range test.NewChecks { diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index 824ea894b..589feefc0 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -43,7 +43,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] } else { query = query.Where(decision.SimulatedEQ(false)) } - t := sql.Table(decision.Table) + t := sql.Table(decision.Table).As("t1") joinPredicate := make([]*sql.Predicate, 0) for param, value := range filter { switch param { @@ -199,9 +199,9 @@ func (c *Client) QueryDecisionCountByScenario(filters map[string][]string) ([]*D func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { now := time.Now().UTC() query := c.Ent.Decision.Query().Where( - decision.UntilLT(time.Now().UTC()), + decision.UntilLTE(now), ) - query, _, err := BuildDecisionRequestWithFilter(query, filters) + query, predicates, err := BuildDecisionRequestWithFilter(query, filters) if err != nil { c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters") @@ -210,12 +210,17 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ( query = query.Where(func(s *sql.Selector) { t := sql.Table(decision.Table).As("t1") - subquery := sql.Select().From(t).Where( + subquery := sql.Select(s.C(decision.FieldValue)).From(t) + for _, pred := range predicates { + subquery.Where(pred) + } + + subquery = subquery.Where( sql.And( - sql.EQ(s.C(decision.FieldScope), t.C(decision.FieldScope)), - sql.EQ(s.C(decision.FieldType), t.C(decision.FieldType)), - sql.EQ(s.C(decision.FieldValue), t.C(decision.FieldValue)), - sql.GT(s.C(decision.FieldUntil), now), + sql.ColumnsEQ(s.C(decision.FieldScope), t.C(decision.FieldScope)), + sql.ColumnsEQ(s.C(decision.FieldType), t.C(decision.FieldType)), + sql.ColumnsEQ(s.C(decision.FieldValue), t.C(decision.FieldValue)), + sql.GT(t.C(decision.FieldUntil), now), ), ) s.Where(sql.NotExists(subquery)) @@ -238,9 +243,9 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters now := time.Now().UTC() query := c.Ent.Decision.Query().Where( - decision.UntilGT(since), + decision.UntilGTE(since), ) - query, _, err := BuildDecisionRequestWithFilter(query, filters) + query, predicates, err := BuildDecisionRequestWithFilter(query, filters) if err != nil { c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters") @@ -249,40 +254,29 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters query = query.Where(func(s *sql.Selector) { t := sql.Table(decision.Table).As("t1") - subquery := sql.Select().From(t).Where( + subquery := sql.Select(s.C(decision.FieldValue)).From(t) + for _, pred := range predicates { + subquery.Where(pred) + } + + subquery = subquery.Where( sql.And( - sql.EQ(s.C(decision.FieldScope), t.C(decision.FieldScope)), - sql.EQ(s.C(decision.FieldType), t.C(decision.FieldType)), - sql.EQ(s.C(decision.FieldValue), t.C(decision.FieldValue)), - sql.GT(s.C(decision.FieldUntil), now), + sql.ColumnsEQ(s.C(decision.FieldScope), t.C(decision.FieldScope)), + sql.ColumnsEQ(s.C(decision.FieldType), t.C(decision.FieldType)), + sql.ColumnsEQ(s.C(decision.FieldValue), t.C(decision.FieldValue)), + sql.GT(t.C(decision.FieldUntil), now), ), ) s.Where(sql.NotExists(subquery)) }) - data, err := query.Order(ent.Asc(decision.FieldValue), ent.Asc(decision.FieldUntil)).All(c.CTX) + data, err := query.Order(ent.Asc(decision.FieldValue), ent.Desc(decision.FieldUntil)).All(c.CTX) if err != nil { c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters") } - ret := make([]*ent.Decision, 0) - deletedDecisions := make(map[string]*ent.Decision) - for _, d := range data { - key := fmt.Sprintf("%s:%s:%s", d.Scope, d.Type, d.Value) - if d.Until.Before(now) { - deletedDecisions[key] = d - } - if d.Until.After(now) { - delete(deletedDecisions, key) - } - } - - for _, d := range deletedDecisions { - ret = append(ret, d) - } - - return ret, nil + return data, nil } func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) {