diff --git a/cmd/crowdsec-cli/utils.go b/cmd/crowdsec-cli/utils.go index 81cd2168b..4f2a74c7d 100644 --- a/cmd/crowdsec-cli/utils.go +++ b/cmd/crowdsec-cli/utils.go @@ -63,6 +63,11 @@ func manageCliDecisionAlerts(ip *string, ipRange *string, scope *string, value * *scope = types.Ip case "range": *scope = types.Range + case "country": + *scope = types.Country + case "as": + *scope = types.AS + } return nil } diff --git a/pkg/acquisition/tests/test.log b/pkg/acquisition/tests/test.log index 6347c5c19..90cd14715 100644 --- a/pkg/acquisition/tests/test.log +++ b/pkg/acquisition/tests/test.log @@ -1 +1 @@ -one log line +one log line \ No newline at end of file diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index ab516e357..484d26d27 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -3,6 +3,7 @@ package apiclient import ( "context" "fmt" + "strings" "github.com/crowdsecurity/crowdsec/pkg/models" qs "github.com/google/go-querystring/query" @@ -52,10 +53,12 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m return &decisions, resp, nil } -func (s *DecisionsService) GetStream(ctx context.Context, startup bool) (*models.DecisionsStreamResponse, *Response, error) { +func (s *DecisionsService) GetStream(ctx context.Context, startup bool, scopes []string) (*models.DecisionsStreamResponse, *Response, error) { var decisions models.DecisionsStreamResponse - u := fmt.Sprintf("%s/decisions/stream?startup=%t", s.client.URLPrefix, startup) + if len(scopes) > 0 { + u += "&scopes=" + strings.Join(scopes, ",") + } req, err := s.client.NewRequest("GET", u, nil) if err != nil { return nil, nil, err diff --git a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go index 1fa94e0ce..3d15b6b8b 100644 --- a/pkg/apiclient/decisions_service_test.go +++ b/pkg/apiclient/decisions_service_test.go @@ -160,7 +160,7 @@ func TestDecisionsStream(t *testing.T) { }, } - decisions, resp, err := newcli.Decisions.GetStream(context.Background(), true) + decisions, resp, err := newcli.Decisions.GetStream(context.Background(), true, []string{}) require.NoError(t, err) if resp.Response.StatusCode != http.StatusOK { @@ -175,7 +175,7 @@ func TestDecisionsStream(t *testing.T) { } //and second call, we get empty lists - decisions, resp, err = newcli.Decisions.GetStream(context.Background(), false) + decisions, resp, err = newcli.Decisions.GetStream(context.Background(), false, []string{}) require.NoError(t, err) if resp.Response.StatusCode != http.StatusOK { diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 82397a07e..dc25a374a 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -234,7 +234,7 @@ func (a *apic) Send(cacheOrig *models.AddSignalsRequest) { func (a *apic) PullTop() error { var err error - data, _, err := a.apiClient.Decisions.GetStream(context.Background(), a.startup) + data, _, err := a.apiClient.Decisions.GetStream(context.Background(), a.startup, []string{}) if err != nil { return errors.Wrap(err, "get stream") } diff --git a/pkg/apiserver/controllers/v1/decisions.go b/pkg/apiserver/controllers/v1/decisions.go index c2646b1c0..17256f7bf 100644 --- a/pkg/apiserver/controllers/v1/decisions.go +++ b/pkg/apiserver/controllers/v1/decisions.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "strconv" + "strings" "time" "github.com/crowdsecurity/crowdsec/pkg/database/ent" @@ -127,10 +128,16 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { return } + filters := make(map[string][]string) + filters["scope"] = []string{"ip", "range"} + if val, ok := gctx.Request.URL.Query()["scopes"]; ok { + filters["scope"] = strings.Split(val[0], ",") + } + // if the blocker just start, return all decisions if val, ok := gctx.Request.URL.Query()["startup"]; ok { if val[0] == "true" { - data, err := c.DBClient.QueryAllDecisions() + data, err := c.DBClient.QueryAllDecisionsWithFilters(filters) if err != nil { log.Errorf("failed querying decisions: %v", err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -144,7 +151,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { } // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisions() + data, err = c.DBClient.QueryExpiredDecisionsWithFilters(filters) if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -172,7 +179,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { } // getting new decisions - data, err = c.DBClient.QueryNewDecisionsSince(bouncerInfo.LastPull) + data, err = c.DBClient.QueryNewDecisionsSinceWithFilters(bouncerInfo.LastPull, filters) if err != nil { log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) @@ -186,7 +193,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) { } // getting expired decisions - data, err = c.DBClient.QueryExpiredDecisionsSince(bouncerInfo.LastPull.Add((-2 * time.Second))) // do we want to give exactly lastPull time ? + data, err = c.DBClient.QueryExpiredDecisionsSinceWithFilters(bouncerInfo.LastPull.Add((-2 * time.Second)), filters) // do we want to give exactly lastPull time ? if err != nil { log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err) gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()}) diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index dc15cb4f0..9d2917730 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -277,7 +277,7 @@ func TestGetDecision(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "\"id\":1,\"origin\":\"test\",\"scenario\":\"crowdsecurity/test\",\"scope\":\"ip\",\"type\":\"ban\",\"value\":\"127.0.0.1\"}]") + assert.Contains(t, w.Body.String(), "\"id\":1,\"origin\":\"test\",\"scenario\":\"crowdsecurity/test\",\"scope\":\"Ip\",\"type\":\"ban\",\"value\":\"127.0.0.1\"}]") } @@ -449,5 +449,5 @@ func TestStreamDecision(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "\"id\":1,\"origin\":\"test\",\"scenario\":\"crowdsecurity/test\",\"scope\":\"ip\",\"type\":\"ban\",\"value\":\"127.0.0.1\"}]}") + assert.Contains(t, w.Body.String(), "\"id\":1,\"origin\":\"test\",\"scenario\":\"crowdsecurity/test\",\"scope\":\"Ip\",\"type\":\"ban\",\"value\":\"127.0.0.1\"}]}") } diff --git a/pkg/apiserver/tests/alert_sample.json b/pkg/apiserver/tests/alert_sample.json index 1abecb447..64b206f3c 100644 --- a/pkg/apiserver/tests/alert_sample.json +++ b/pkg/apiserver/tests/alert_sample.json @@ -10,7 +10,7 @@ "duration": "1h", "origin": "test", "scenario": "crowdsecurity/test", - "scope": "ip", + "scope": "Ip", "value": "127.0.0.1", "type": "ban" } diff --git a/pkg/database/decisions.go b/pkg/database/decisions.go index 50cd6a6ef..5f542db50 100644 --- a/pkg/database/decisions.go +++ b/pkg/database/decisions.go @@ -41,13 +41,19 @@ func BuildDecisionRequestWithFilter(query *ent.DecisionQuery, filter map[string] return nil, errors.Wrapf(InvalidFilter, "invalid contains value : %s", err) } case "scope": - var scope string = value[0] - if strings.ToLower(scope) == "ip" { - scope = types.Ip - } else if strings.ToLower(scope) == "range" { - scope = types.Range + for i, scope := range value { + switch strings.ToLower(scope) { + case "ip": + value[i] = types.Ip + case "range": + value[i] = types.Range + case "country": + value[i] = types.Country + case "as": + value[i] = types.AS + } } - query = query.Where(decision.ScopeEQ(scope)) + query = query.Where(decision.ScopeIn(value...)) case "value": query = query.Where(decision.ValueEQ(value[0])) case "type": @@ -165,37 +171,66 @@ func (c *Client) QueryDecisionWithFilter(filter map[string][]string) ([]*ent.Dec return data, nil } -func (c *Client) QueryAllDecisions() ([]*ent.Decision, error) { - data, err := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now())).All(c.CTX) +func (c *Client) QueryAllDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { + query := c.Ent.Decision.Query().Where(decision.UntilGT(time.Now())) + query, err := BuildDecisionRequestWithFilter(query, filters) + if err != nil { - c.Log.Warningf("QueryAllDecisions : %s", err) - return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions") + c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err) + return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters") + } + + data, err := query.All(c.CTX) + if err != nil { + c.Log.Warningf("QueryAllDecisionsWithFilters : %s", err) + return []*ent.Decision{}, errors.Wrap(QueryFail, "get all decisions with filters") } return data, nil } -func (c *Client) QueryExpiredDecisions() ([]*ent.Decision, error) { - data, err := c.Ent.Decision.Query().Where(decision.UntilLT(time.Now())).All(c.CTX) +func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) ([]*ent.Decision, error) { + query := c.Ent.Decision.Query().Where(decision.UntilLT(time.Now())) + query, err := BuildDecisionRequestWithFilter(query, filters) + if err != nil { - c.Log.Warningf("QueryExpiredDecisions : %s", err) + c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err) + return []*ent.Decision{}, errors.Wrap(QueryFail, "get expired decisions with filters") + } + data, err := query.All(c.CTX) + if err != nil { + c.Log.Warningf("QueryExpiredDecisionsWithFilters : %s", err) return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions") } return data, nil } -func (c *Client) QueryExpiredDecisionsSince(since time.Time) ([]*ent.Decision, error) { - data, err := c.Ent.Decision.Query().Where(decision.UntilLT(time.Now())).Where(decision.UntilGT(since)).All(c.CTX) +func (c *Client) QueryExpiredDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) { + query := c.Ent.Decision.Query().Where(decision.UntilLT(time.Now())).Where(decision.UntilGT(since)) + query, err := BuildDecisionRequestWithFilter(query, filters) if err != nil { - c.Log.Warningf("QueryExpiredDecisionsSince : %s", err) - return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions") + c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) + return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters") } + + data, err := query.All(c.CTX) + if err != nil { + c.Log.Warningf("QueryExpiredDecisionsSinceWithFilters : %s", err) + return []*ent.Decision{}, errors.Wrap(QueryFail, "expired decisions with filters") + } + return data, nil } -func (c *Client) QueryNewDecisionsSince(since time.Time) ([]*ent.Decision, error) { - data, err := c.Ent.Decision.Query().Where(decision.CreatedAtGT(since)).All(c.CTX) +func (c *Client) QueryNewDecisionsSinceWithFilters(since time.Time, filters map[string][]string) ([]*ent.Decision, error) { + query := c.Ent.Decision.Query().Where(decision.CreatedAtGT(since)) + query, err := BuildDecisionRequestWithFilter(query, filters) if err != nil { - c.Log.Warningf("QueryNewDecisionsSince : %s", err) + c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err) + return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String()) + } + data, err := query.All(c.CTX) + if err != nil { + c.Log.Warningf("QueryNewDecisionsSinceWithFilters : %s", err) return []*ent.Decision{}, errors.Wrapf(QueryFail, "new decisions since '%s'", since.String()) } return data, nil diff --git a/pkg/types/event.go b/pkg/types/event.go index 14809d6bd..ff04619cd 100644 --- a/pkg/types/event.go +++ b/pkg/types/event.go @@ -57,6 +57,8 @@ const ( Ip = "Ip" Range = "Range" Filter = "Filter" + Country = "Country" + AS = "AS" ) //Move in leakybuckets