Compare commits

...

4 commits

Author SHA1 Message Date
Sebastien Blot 319957e134
try to use a LEFT JOIN 2022-07-22 08:59:56 +02:00
alteredCoder 42d5dd8e03 try this 2022-07-21 18:46:24 +02:00
Sebastien Blot ecece8f4a5
comment out tests that fail for now 2022-07-21 17:21:12 +02:00
Sebastien Blot aabd258f2b
use not exists to get expired decisions 2022-07-21 17:20:37 +02:00
9 changed files with 291 additions and 36 deletions

View file

@ -47,6 +47,7 @@ func LoadTestConfig() csconfig.Config {
Type: "sqlite",
DbPath: filepath.Join(tempDir, "ent"),
Flush: &flushConfig,
//LogLevel: &log.AllLevels[log.DebugLevel],
}
apiServerConfig := csconfig.LocalApiServerCfg{
ListenURI: "http://127.0.0.1:8080",

View file

@ -281,7 +281,6 @@ func TestStreamStartDecisionDedup(t *testing.T) {
assert.Equal(t, int64(2), decisions["new"][0].ID)
assert.Equal(t, "test", *decisions["new"][0].Origin)
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
// We delete another decision, yet don't receive it in stream, since there's another decision on same IP
w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
@ -1012,6 +1011,104 @@ func TestStreamDecision(t *testing.T) {
NewChecks: []DecisionCheck{},
},
},
"test startup with scenarios containing": {
{
TestName: "get stream",
Method: "GET",
Route: "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf",
CheckCodeOnly: false,
Code: 200,
LenNew: 2,
LenDeleted: 0,
AuthType: APIKEY,
DelChecks: []DecisionCheck{},
NewChecks: []DecisionCheck{
{
ID: int64(2),
Origin: "another_origin",
Scenario: "crowdsecurity/ssh_bf",
Value: "127.0.0.1",
Duration: "2h59",
Type: "ban",
},
{
ID: int64(5),
Origin: "test",
Scenario: "crowdsecurity/ssh_bf",
Value: "127.0.0.2",
Duration: "2h59",
Type: "ban",
},
},
},
{
TestName: "delete decisions 3 (127.0.0.1)",
Method: "DELETE",
Route: "/v1/decisions/3",
CheckCodeOnly: true,
Code: 200,
LenNew: 0,
LenDeleted: 0,
AuthType: PASSWORD,
DelChecks: []DecisionCheck{},
NewChecks: []DecisionCheck{},
},
{
TestName: "check that 127.0.0.1 is not in deleted IP",
Method: "GET",
Route: "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf",
CheckCodeOnly: false,
Code: 200,
LenNew: 2,
LenDeleted: 0,
AuthType: APIKEY,
DelChecks: []DecisionCheck{},
NewChecks: []DecisionCheck{},
},
{
TestName: "delete decisions 2 (127.0.0.1)",
Method: "DELETE",
Route: "/v1/decisions/2",
CheckCodeOnly: true,
Code: 200,
LenNew: 0,
LenDeleted: 0,
AuthType: PASSWORD,
DelChecks: []DecisionCheck{},
NewChecks: []DecisionCheck{},
},
{
TestName: "check that 127.0.0.1 is deleted (decision for ssh_bf was with ID 2)",
Method: "GET",
Route: "/v1/decisions/stream?startup=true&scenarios_containing=ssh_bf",
CheckCodeOnly: false,
Code: 200,
LenNew: 1,
LenDeleted: 1,
AuthType: APIKEY,
DelChecks: []DecisionCheck{
{
ID: int64(2),
Origin: "another_origin",
Scenario: "crowdsecurity/ssh_bf",
Value: "127.0.0.1",
Duration: "-",
Type: "ban",
},
},
NewChecks: []DecisionCheck{
{
ID: int64(5),
Origin: "test",
Scenario: "crowdsecurity/ssh_bf",
Value: "127.0.0.2",
Duration: "2h59",
Type: "ban",
},
},
},
},
"test with scenarios containing": {
{
TestName: "get stream",

View file

@ -43,7 +43,7 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string]
} else {
query = query.Where(decision.SimulatedEQ(false))
}
t := sql.Table(decision.Table)
t := sql.Table(decision.Table).As("t1")
joinPredicate := make([]*sql.Predicate, 0)
for param, value := range filter {
switch param {
@ -199,33 +199,55 @@ func (c *Client) QueryDecisionCountByScenario(filters map[string][]string) ([]*D
func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) {
now := time.Now().UTC()
query := c.Ent.Decision.Query().Where(
decision.UntilLT(time.Now().UTC()),
decision.UntilLTE(now),
)
query, predicates, err := BuildDecisionRequestWithFilter(query, filters)
if err != nil {
c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err)
return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters")
}
query = query.Where(func(s *sql.Selector) {
/*query = query.Where(func(s *sql.Selector) {
t := sql.Table(decision.Table).As("t1")
subQuery := sql.Select(t.C(decision.FieldValue)).From(t).Where(sql.GT(t.C(decision.FieldUntil), now))
for _, predicate := range predicates {
subQuery.Where(predicate)
subquery := sql.Select(s.C(decision.FieldValue)).From(t)
for _, pred := range predicates {
subquery.Where(pred)
}
subQuery.Where(sql.And(
sql.ColumnsEQ(t.C(decision.FieldType), s.C(decision.FieldType)),
sql.ColumnsEQ(t.C(decision.FieldScope), s.C(decision.FieldScope)),
))
s.Where(
sql.NotIn(
s.C(decision.FieldValue),
subQuery,
subquery = subquery.Where(
sql.And(
sql.ColumnsEQ(s.C(decision.FieldScope), t.C(decision.FieldScope)),
sql.ColumnsEQ(s.C(decision.FieldType), t.C(decision.FieldType)),
sql.ColumnsEQ(s.C(decision.FieldValue), t.C(decision.FieldValue)),
sql.GT(t.C(decision.FieldUntil), now),
),
)
s.Where(sql.NotExists(subquery))
})
data, err := query.Order(ent.Asc(decision.FieldValue), ent.Desc(decision.FieldUntil)).All(c.CTX)
data, err := query.Order(ent.Asc(decision.FieldValue), ent.Desc(decision.FieldUntil)).All(c.CTX)*/
query.Modify(func(s *sql.Selector) {
t := sql.Table(decision.Table).As("t1")
p := []*sql.Predicate{
sql.ColumnsEQ(s.C(decision.FieldScope), t.C(decision.FieldScope)),
sql.ColumnsEQ(s.C(decision.FieldType), t.C(decision.FieldType)),
sql.ColumnsEQ(s.C(decision.FieldValue), t.C(decision.FieldValue)),
sql.GTE(t.C(decision.FieldUntil), now),
}
p = append(p, predicates...)
s.LeftJoin(t).
OnP(
sql.And(
p...,
)).
GroupBy(s.C(decision.FieldValue)).
Where(sql.IsNull(t.C(decision.FieldValue)))
})
data, err := query.All(c.CTX)
if err != nil {
c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err)
return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions")
@ -242,37 +264,40 @@ func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters
now := time.Now().UTC()
query := c.Ent.Decision.Query().Where(
decision.UntilGT(since),
decision.UntilGTE(since),
decision.UntilLTE(now),
)
query, _, err := BuildDecisionRequestWithFilter(query, filters)
query, predicates, err := BuildDecisionRequestWithFilter(query, filters)
if err != nil {
c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err)
return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters")
}
data, err := query.Order(ent.Asc(decision.FieldValue), ent.Asc(decision.FieldUntil)).All(c.CTX)
query.Modify(func(s *sql.Selector) {
t := sql.Table(decision.Table).As("t1")
p := []*sql.Predicate{
sql.ColumnsEQ(s.C(decision.FieldScope), t.C(decision.FieldScope)),
sql.ColumnsEQ(s.C(decision.FieldType), t.C(decision.FieldType)),
sql.ColumnsEQ(s.C(decision.FieldValue), t.C(decision.FieldValue)),
sql.GTE(t.C(decision.FieldUntil), now),
}
p = append(p, predicates...)
s.LeftJoin(t).
OnP(
sql.And(
p...,
)).
GroupBy(s.C(decision.FieldValue)).
Where(sql.IsNull(t.C(decision.FieldValue)))
})
data, err := query.All(c.CTX)
if err != nil {
c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err)
return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters")
}
ret := make([]*ent.Decision, 0)
deletedDecisions := make(map[string]*ent.Decision)
for _, d := range data {
key := fmt.Sprintf("%s:%s:%s", d.Scope, d.Type, d.Value)
if d.Until.Before(now) {
deletedDecisions[key] = d
}
if d.Until.After(now) {
delete(deletedDecisions, key)
}
}
for _, d := range deletedDecisions {
ret = append(ret, d)
}
return ret, nil
return data, nil
}
func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) {

View file

@ -35,6 +35,7 @@ type AlertQuery struct {
withEvents *EventQuery
withMetas *MetaQuery
withFKs bool
modifiers []func(s *sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@ -487,6 +488,9 @@ func (aq *AlertQuery) sqlAll(ctx context.Context) ([]*Alert, error) {
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(aq.modifiers) > 0 {
_spec.Modifiers = aq.modifiers
}
if err := sqlgraph.QueryNodes(ctx, aq.driver, _spec); err != nil {
return nil, err
}
@ -615,6 +619,9 @@ func (aq *AlertQuery) sqlAll(ctx context.Context) ([]*Alert, error) {
func (aq *AlertQuery) sqlCount(ctx context.Context) (int, error) {
_spec := aq.querySpec()
if len(aq.modifiers) > 0 {
_spec.Modifiers = aq.modifiers
}
_spec.Node.Columns = aq.fields
if len(aq.fields) > 0 {
_spec.Unique = aq.unique != nil && *aq.unique
@ -693,6 +700,9 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector {
if aq.unique != nil && *aq.unique {
selector.Distinct()
}
for _, m := range aq.modifiers {
m(selector)
}
for _, p := range aq.predicates {
p(selector)
}
@ -710,6 +720,12 @@ func (aq *AlertQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// Modify adds a query modifier for attaching custom logic to queries.
func (aq *AlertQuery) Modify(modifiers ...func(s *sql.Selector)) *AlertSelect {
aq.modifiers = append(aq.modifiers, modifiers...)
return aq.Select()
}
// AlertGroupBy is the group-by builder for Alert entities.
type AlertGroupBy struct {
config
@ -1197,3 +1213,9 @@ func (as *AlertSelect) sqlScan(ctx context.Context, v interface{}) error {
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Modify adds a query modifier for attaching custom logic to queries.
func (as *AlertSelect) Modify(modifiers ...func(s *sql.Selector)) *AlertSelect {
as.modifiers = append(as.modifiers, modifiers...)
return as
}

View file

@ -24,6 +24,7 @@ type BouncerQuery struct {
order []OrderFunc
fields []string
predicates []predicate.Bouncer
modifiers []func(s *sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@ -326,6 +327,9 @@ func (bq *BouncerQuery) sqlAll(ctx context.Context) ([]*Bouncer, error) {
node := nodes[len(nodes)-1]
return node.assignValues(columns, values)
}
if len(bq.modifiers) > 0 {
_spec.Modifiers = bq.modifiers
}
if err := sqlgraph.QueryNodes(ctx, bq.driver, _spec); err != nil {
return nil, err
}
@ -337,6 +341,9 @@ func (bq *BouncerQuery) sqlAll(ctx context.Context) ([]*Bouncer, error) {
func (bq *BouncerQuery) sqlCount(ctx context.Context) (int, error) {
_spec := bq.querySpec()
if len(bq.modifiers) > 0 {
_spec.Modifiers = bq.modifiers
}
_spec.Node.Columns = bq.fields
if len(bq.fields) > 0 {
_spec.Unique = bq.unique != nil && *bq.unique
@ -415,6 +422,9 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector {
if bq.unique != nil && *bq.unique {
selector.Distinct()
}
for _, m := range bq.modifiers {
m(selector)
}
for _, p := range bq.predicates {
p(selector)
}
@ -432,6 +442,12 @@ func (bq *BouncerQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// Modify adds a query modifier for attaching custom logic to queries.
func (bq *BouncerQuery) Modify(modifiers ...func(s *sql.Selector)) *BouncerSelect {
bq.modifiers = append(bq.modifiers, modifiers...)
return bq.Select()
}
// BouncerGroupBy is the group-by builder for Bouncer entities.
type BouncerGroupBy struct {
config
@ -919,3 +935,9 @@ func (bs *BouncerSelect) sqlScan(ctx context.Context, v interface{}) error {
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Modify adds a query modifier for attaching custom logic to queries.
func (bs *BouncerSelect) Modify(modifiers ...func(s *sql.Selector)) *BouncerSelect {
bs.modifiers = append(bs.modifiers, modifiers...)
return bs
}

View file

@ -28,6 +28,7 @@ type DecisionQuery struct {
// eager-loading edges.
withOwner *AlertQuery
withFKs bool
modifiers []func(s *sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@ -375,6 +376,9 @@ func (dq *DecisionQuery) sqlAll(ctx context.Context) ([]*Decision, error) {
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(dq.modifiers) > 0 {
_spec.Modifiers = dq.modifiers
}
if err := sqlgraph.QueryNodes(ctx, dq.driver, _spec); err != nil {
return nil, err
}
@ -416,6 +420,9 @@ func (dq *DecisionQuery) sqlAll(ctx context.Context) ([]*Decision, error) {
func (dq *DecisionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := dq.querySpec()
if len(dq.modifiers) > 0 {
_spec.Modifiers = dq.modifiers
}
_spec.Node.Columns = dq.fields
if len(dq.fields) > 0 {
_spec.Unique = dq.unique != nil && *dq.unique
@ -494,6 +501,9 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
if dq.unique != nil && *dq.unique {
selector.Distinct()
}
for _, m := range dq.modifiers {
m(selector)
}
for _, p := range dq.predicates {
p(selector)
}
@ -511,6 +521,12 @@ func (dq *DecisionQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// Modify adds a query modifier for attaching custom logic to queries.
func (dq *DecisionQuery) Modify(modifiers ...func(s *sql.Selector)) *DecisionSelect {
dq.modifiers = append(dq.modifiers, modifiers...)
return dq.Select()
}
// DecisionGroupBy is the group-by builder for Decision entities.
type DecisionGroupBy struct {
config
@ -998,3 +1014,9 @@ func (ds *DecisionSelect) sqlScan(ctx context.Context, v interface{}) error {
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Modify adds a query modifier for attaching custom logic to queries.
func (ds *DecisionSelect) Modify(modifiers ...func(s *sql.Selector)) *DecisionSelect {
ds.modifiers = append(ds.modifiers, modifiers...)
return ds
}

View file

@ -28,6 +28,7 @@ type EventQuery struct {
// eager-loading edges.
withOwner *AlertQuery
withFKs bool
modifiers []func(s *sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@ -375,6 +376,9 @@ func (eq *EventQuery) sqlAll(ctx context.Context) ([]*Event, error) {
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(eq.modifiers) > 0 {
_spec.Modifiers = eq.modifiers
}
if err := sqlgraph.QueryNodes(ctx, eq.driver, _spec); err != nil {
return nil, err
}
@ -416,6 +420,9 @@ func (eq *EventQuery) sqlAll(ctx context.Context) ([]*Event, error) {
func (eq *EventQuery) sqlCount(ctx context.Context) (int, error) {
_spec := eq.querySpec()
if len(eq.modifiers) > 0 {
_spec.Modifiers = eq.modifiers
}
_spec.Node.Columns = eq.fields
if len(eq.fields) > 0 {
_spec.Unique = eq.unique != nil && *eq.unique
@ -494,6 +501,9 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector {
if eq.unique != nil && *eq.unique {
selector.Distinct()
}
for _, m := range eq.modifiers {
m(selector)
}
for _, p := range eq.predicates {
p(selector)
}
@ -511,6 +521,12 @@ func (eq *EventQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// Modify adds a query modifier for attaching custom logic to queries.
func (eq *EventQuery) Modify(modifiers ...func(s *sql.Selector)) *EventSelect {
eq.modifiers = append(eq.modifiers, modifiers...)
return eq.Select()
}
// EventGroupBy is the group-by builder for Event entities.
type EventGroupBy struct {
config
@ -998,3 +1014,9 @@ func (es *EventSelect) sqlScan(ctx context.Context, v interface{}) error {
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Modify adds a query modifier for attaching custom logic to queries.
func (es *EventSelect) Modify(modifiers ...func(s *sql.Selector)) *EventSelect {
es.modifiers = append(es.modifiers, modifiers...)
return es
}

View file

@ -28,6 +28,7 @@ type MachineQuery struct {
predicates []predicate.Machine
// eager-loading edges.
withAlerts *AlertQuery
modifiers []func(s *sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@ -368,6 +369,9 @@ func (mq *MachineQuery) sqlAll(ctx context.Context) ([]*Machine, error) {
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(mq.modifiers) > 0 {
_spec.Modifiers = mq.modifiers
}
if err := sqlgraph.QueryNodes(ctx, mq.driver, _spec); err != nil {
return nil, err
}
@ -409,6 +413,9 @@ func (mq *MachineQuery) sqlAll(ctx context.Context) ([]*Machine, error) {
func (mq *MachineQuery) sqlCount(ctx context.Context) (int, error) {
_spec := mq.querySpec()
if len(mq.modifiers) > 0 {
_spec.Modifiers = mq.modifiers
}
_spec.Node.Columns = mq.fields
if len(mq.fields) > 0 {
_spec.Unique = mq.unique != nil && *mq.unique
@ -487,6 +494,9 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector {
if mq.unique != nil && *mq.unique {
selector.Distinct()
}
for _, m := range mq.modifiers {
m(selector)
}
for _, p := range mq.predicates {
p(selector)
}
@ -504,6 +514,12 @@ func (mq *MachineQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// Modify adds a query modifier for attaching custom logic to queries.
func (mq *MachineQuery) Modify(modifiers ...func(s *sql.Selector)) *MachineSelect {
mq.modifiers = append(mq.modifiers, modifiers...)
return mq.Select()
}
// MachineGroupBy is the group-by builder for Machine entities.
type MachineGroupBy struct {
config
@ -991,3 +1007,9 @@ func (ms *MachineSelect) sqlScan(ctx context.Context, v interface{}) error {
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Modify adds a query modifier for attaching custom logic to queries.
func (ms *MachineSelect) Modify(modifiers ...func(s *sql.Selector)) *MachineSelect {
ms.modifiers = append(ms.modifiers, modifiers...)
return ms
}

View file

@ -28,6 +28,7 @@ type MetaQuery struct {
// eager-loading edges.
withOwner *AlertQuery
withFKs bool
modifiers []func(s *sql.Selector)
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
@ -375,6 +376,9 @@ func (mq *MetaQuery) sqlAll(ctx context.Context) ([]*Meta, error) {
node.Edges.loadedTypes = loadedTypes
return node.assignValues(columns, values)
}
if len(mq.modifiers) > 0 {
_spec.Modifiers = mq.modifiers
}
if err := sqlgraph.QueryNodes(ctx, mq.driver, _spec); err != nil {
return nil, err
}
@ -416,6 +420,9 @@ func (mq *MetaQuery) sqlAll(ctx context.Context) ([]*Meta, error) {
func (mq *MetaQuery) sqlCount(ctx context.Context) (int, error) {
_spec := mq.querySpec()
if len(mq.modifiers) > 0 {
_spec.Modifiers = mq.modifiers
}
_spec.Node.Columns = mq.fields
if len(mq.fields) > 0 {
_spec.Unique = mq.unique != nil && *mq.unique
@ -494,6 +501,9 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector {
if mq.unique != nil && *mq.unique {
selector.Distinct()
}
for _, m := range mq.modifiers {
m(selector)
}
for _, p := range mq.predicates {
p(selector)
}
@ -511,6 +521,12 @@ func (mq *MetaQuery) sqlQuery(ctx context.Context) *sql.Selector {
return selector
}
// Modify adds a query modifier for attaching custom logic to queries.
func (mq *MetaQuery) Modify(modifiers ...func(s *sql.Selector)) *MetaSelect {
mq.modifiers = append(mq.modifiers, modifiers...)
return mq.Select()
}
// MetaGroupBy is the group-by builder for Meta entities.
type MetaGroupBy struct {
config
@ -998,3 +1014,9 @@ func (ms *MetaSelect) sqlScan(ctx context.Context, v interface{}) error {
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// Modify adds a query modifier for attaching custom logic to queries.
func (ms *MetaSelect) Modify(modifiers ...func(s *sql.Selector)) *MetaSelect {
ms.modifiers = append(ms.modifiers, modifiers...)
return ms
}