From ecb32d74c64afbdb4e29bbc3f1680717fad31a49 Mon Sep 17 00:00:00 2001 From: Cristian Nitescu Date: Mon, 13 Feb 2023 15:06:14 +0100 Subject: [PATCH] optimize blocklist fetch (#2039) --- pkg/apiclient/client.go | 2 +- pkg/apiclient/decisions_service.go | 60 ++++++--- pkg/apiclient/decisions_service_test.go | 34 +++++- pkg/apiserver/apic.go | 59 ++++++++- pkg/apiserver/apic_test.go | 154 +++++++++++++++++++++++- 5 files changed, 280 insertions(+), 29 deletions(-) diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index 52a7c6f11..a7db1ed33 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -158,7 +158,7 @@ func newResponse(r *http.Response) *Response { } func CheckResponse(r *http.Response) error { - if c := r.StatusCode; 200 <= c && c <= 299 { + if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 { return nil } errorResponse := &ErrorResponse{} diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index f7844fb8a..4cc33fe7b 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -4,7 +4,6 @@ import ( "bufio" "context" "fmt" - "io" "net/http" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -150,29 +149,58 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m return &v2Decisions, resp, nil } -func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink) ([]*models.Decision, error) { +func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp *string) ([]*models.Decision, bool, error) { if blocklist.URL == nil { - return nil, errors.New("blocklist URL is nil") + return nil, false, errors.New("blocklist URL is nil") } log.Debugf("Fetching blocklist %s", *blocklist.URL) - req, err := s.client.NewRequest(http.MethodGet, *blocklist.URL, nil) + client := http.Client{} + req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil) if err != nil { - return nil, err + return nil, false, err } - pr, pw := io.Pipe() - defer pr.Close() - go func() { - defer pw.Close() - _, err = s.client.Do(ctx, req, pw) - if err != nil { - log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err) + if lastPullTimestamp != nil { + req.Header.Set("If-Modified-Since", *lastPullTimestamp) + } + req = req.WithContext(ctx) + log.Debugf("[URL] %s %s", req.Method, req.URL) + // we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc + resp, err := client.Do(req) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + if err != nil { + // If we got an error, and the context has been canceled, + // the context's error is probably more useful. + select { + case <-ctx.Done(): + return nil, false, ctx.Err() + default: } - }() + + // If the error type is *url.Error, sanitize its URL before returning. + log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err) + return nil, false, err + } + + if resp.StatusCode == http.StatusNotModified { + if lastPullTimestamp != nil { + log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, *lastPullTimestamp) + } else { + log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL) + } + return nil, false, nil + } + if resp.StatusCode != http.StatusOK { + log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL) + return nil, false, nil + } decisions := make([]*models.Decision, 0) - scanner := bufio.NewScanner(pr) + scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { decision := scanner.Text() decisions = append(decisions, &models.Decision{ @@ -185,7 +213,9 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl }) } - return decisions, nil + // here the upper go routine is finished because scanner.Scan() is blocking until pw.Close() is called + // so it's safe to use the isModified variable here + return decisions, true, nil } func (s *DecisionsService) GetStream(ctx context.Context, opts DecisionsStreamOpts) (*models.DecisionsStreamResponse, *Response, error) { diff --git a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go index 92a9c448a..35d819d5a 100644 --- a/pkg/apiclient/decisions_service_test.go +++ b/pkg/apiclient/decisions_service_test.go @@ -11,6 +11,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/modelscapi" + "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -377,9 +378,11 @@ func TestDecisionsFromBlocklist(t *testing.T) { defer teardown() mux.HandleFunc("/blocklist", func(w http.ResponseWriter, r *http.Request) { - - assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu") testMethod(t, r, http.MethodGet) + if r.Header.Get("If-Modified-Since") == "Sun, 01 Jan 2023 01:01:01 GMT" { + w.WriteHeader(http.StatusNotModified) + return + } if r.Method == http.MethodGet { w.WriteHeader(http.StatusOK) w.Write([]byte("1.2.3.4\r\n1.2.3.5")) @@ -407,7 +410,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { tnameBlocklist := "blocklist1" tremediationBlocklist := "ban" tscopeBlocklist := "ip" - turlBlocklist := "/v3/blocklist" + turlBlocklist := urlx + "/v3/blocklist" torigin := "lists" expected := []*models.Decision{ { @@ -427,14 +430,15 @@ func TestDecisionsFromBlocklist(t *testing.T) { Origin: &torigin, }, } - decisions, err := newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{ + decisions, isModified, err := newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{ URL: &turlBlocklist, Scope: &tscopeBlocklist, Remediation: &tremediationBlocklist, Name: &tnameBlocklist, Duration: &tdurationBlocklist, - }) + }, nil) require.NoError(t, err) + assert.True(t, isModified) log.Infof("decision1: %+v", decisions[0]) log.Infof("expected1: %+v", expected[0]) @@ -448,6 +452,26 @@ func TestDecisionsFromBlocklist(t *testing.T) { if !reflect.DeepEqual(decisions, expected) { t.Fatalf("returned %+v, want %+v", decisions, expected) } + + // test cache control + _, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{ + URL: &turlBlocklist, + Scope: &tscopeBlocklist, + Remediation: &tremediationBlocklist, + Name: &tnameBlocklist, + Duration: &tdurationBlocklist, + }, types.StrPtr("Sun, 01 Jan 2023 01:01:01 GMT")) + require.NoError(t, err) + assert.False(t, isModified) + _, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{ + URL: &turlBlocklist, + Scope: &tscopeBlocklist, + Remediation: &tremediationBlocklist, + Name: &tnameBlocklist, + Duration: &tdurationBlocklist, + }, types.StrPtr("Mon, 02 Jan 2023 01:01:01 GMT")) + require.NoError(t, err) + assert.True(t, isModified) } func TestDeleteDecisions(t *testing.T) { diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 1abf83057..d6983b121 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -20,6 +20,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -598,6 +599,36 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, add_counters map[strin return nil } +func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bool, error) { + // we should force pull if the blocklist decisions are about to expire or there's no decision in the db + alertQuery := a.dbClient.Ent.Alert.Query() + alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name))) + alertQuery.Order(ent.Desc(alert.FieldCreatedAt)) + alertInstance, err := alertQuery.First(context.Background()) + if err != nil { + if ent.IsNotFound(err) { + log.Debugf("no alert found for %s, force refresh", *blocklist.Name) + return true, nil + } + return false, errors.Wrap(err, "while getting alert") + } + decisionQuery := a.dbClient.Ent.Decision.Query() + decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID))) + firstDecision, err := decisionQuery.First(context.Background()) + if err != nil { + if ent.IsNotFound(err) { + log.Debugf("no decision found for %s, force refresh", *blocklist.Name) + return true, nil + } + return false, errors.Wrap(err, "while getting decision") + } + if firstDecision == nil || firstDecision.Until == nil || firstDecision.Until.Sub(time.Now().UTC()) < (a.pullInterval+15*time.Minute) { + log.Debugf("at least one decision found for %s, expire soon, force refresh", *blocklist.Name) + return true, nil + } + return false, nil +} + func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int) error { if links == nil { return nil @@ -607,7 +638,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink } // we must use a different http client than apiClient's because the transport of apiClient is jwtTransport or here we have signed apis that are incompatibles // we can use the same baseUrl as the urls are absolute and the parse will take care of it - defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", &http.Client{}) + defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", nil) if err != nil { return errors.Wrap(err, "while creating default client") } @@ -620,10 +651,34 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink log.Warningf("blocklist has no duration") continue } - decisions, err := defaultClient.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist) + forcePull, err := a.ShouldForcePullBlocklist(blocklist) + if err != nil { + return errors.Wrapf(err, "while checking if we should force pull blocklist %s", *blocklist.Name) + } + blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name) + var lastPullTimestamp *string + if !forcePull { + lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) + if err != nil { + return errors.Wrapf(err, "while getting last pull timestamp for blocklist %s", *blocklist.Name) + } + } + decisions, has_changed, err := defaultClient.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp) if err != nil { return errors.Wrapf(err, "while getting decisions from blocklist %s", *blocklist.Name) } + if !has_changed { + if lastPullTimestamp == nil { + log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name) + } else { + log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp) + } + continue + } + err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat)) + if err != nil { + return errors.Wrapf(err, "while setting last pull timestamp for blocklist %s", *blocklist.Name) + } if len(decisions) == 0 { log.Infof("blocklist %s has no decisions", *blocklist.Name) continue diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 740776eb0..9168d59ef 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "net/url" "os" "reflect" @@ -568,14 +569,14 @@ func TestAPICPullTop(t *testing.T) { Blocklists: []*modelscapi.BlocklistLink{ { URL: types.StrPtr("http://api.crowdsec.net/blocklist1"), - Name: types.StrPtr("crowdsecurity/http-bf"), + Name: types.StrPtr("blocklist1"), Scope: types.StrPtr("Ip"), Remediation: types.StrPtr("ban"), Duration: types.StrPtr("24h"), }, { URL: types.StrPtr("http://api.crowdsec.net/blocklist2"), - Name: types.StrPtr("crowdsecurity/ssh-bf"), + Name: types.StrPtr("blocklist2"), Scope: types.StrPtr("Ip"), Remediation: types.StrPtr("ban"), Duration: types.StrPtr("24h"), @@ -622,19 +623,160 @@ func TestAPICPullTop(t *testing.T) { } assert.Equal(t, 3, len(alertScenario)) assert.Equal(t, 1, alertScenario[SCOPE_CAPI_ALIAS_ALIAS]) - assert.Equal(t, 1, alertScenario["lists:crowdsecurity/ssh-bf"]) - assert.Equal(t, 1, alertScenario["lists:crowdsecurity/http-bf"]) + assert.Equal(t, 1, alertScenario["lists:blocklist1"]) + assert.Equal(t, 1, alertScenario["lists:blocklist2"]) for _, decisions := range validDecisions { decisionScenarioFreq[decisions.Scenario]++ } - assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/http-bf"], 1) - assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1) + assert.Equal(t, 1, decisionScenarioFreq["blocklist1"], 1) + assert.Equal(t, 1, decisionScenarioFreq["blocklist2"], 1) assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test1"], 1) assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test2"], 1) } +func TestAPICPullTopBLCacheFirstCall(t *testing.T) { + // no decision in db, no last modified parameter. + api := getAPIC(t) + httpmock.Activate() + defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( + 200, jsonMarshalX( + modelscapi.GetDecisionsStreamResponse{ + New: modelscapi.GetDecisionsStreamResponseNew{ + &modelscapi.GetDecisionsStreamResponseNewItem{ + Scenario: types.StrPtr("crowdsecurity/test1"), + Scope: types.StrPtr("Ip"), + Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ + { + Value: types.StrPtr("1.2.3.4"), + Duration: types.StrPtr("24h"), + }, + }, + }, + }, + Links: &modelscapi.GetDecisionsStreamResponseLinks{ + Blocklists: []*modelscapi.BlocklistLink{ + { + URL: types.StrPtr("http://api.crowdsec.net/blocklist1"), + Name: types.StrPtr("blocklist1"), + Scope: types.StrPtr("Ip"), + Remediation: types.StrPtr("ban"), + Duration: types.StrPtr("24h"), + }, + }, + }, + }, + ), + )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { + assert.Equal(t, "", req.Header.Get("If-Modified-Since")) + return httpmock.NewStringResponse(200, "1.2.3.4"), nil + }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") + require.NoError(t, err) + + apic, err := apiclient.NewDefaultClient( + url, + "/api", + fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), + nil, + ) + require.NoError(t, err) + + api.apiClient = apic + err = api.PullTop() + require.NoError(t, err) + + blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *types.StrPtr("blocklist1")) + lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + require.NoError(t, err) + assert.NotEqual(t, "", *lastPullTimestamp) + + // new call should return 304 and should not change lastPullTimestamp + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { + assert.NotEqual(t, "", req.Header.Get("If-Modified-Since")) + return httpmock.NewStringResponse(304, ""), nil + }) + err = api.PullTop() + require.NoError(t, err) + secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) + require.NoError(t, err) + assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp) +} + +func TestAPICPullTopBLCacheForceCall(t *testing.T) { + api := getAPIC(t) + httpmock.Activate() + defer httpmock.DeactivateAndReset() + // create a decision about to expire. It should force fetch + alertInstance := api.dbClient.Ent.Alert. + Create(). + SetScenario("update list"). + SetSourceScope("list:blocklist1"). + SetSourceValue("list:blocklist1"). + SaveX(context.Background()) + + api.dbClient.Ent.Decision.Create(). + SetOrigin(types.ListOrigin). + SetType("ban"). + SetValue("9.9.9.9"). + SetScope("Ip"). + SetScenario("blocklist1"). + SetUntil(time.Now().Add(time.Hour)). + SetOwnerID(alertInstance.ID). + ExecX(context.Background()) + + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( + 200, jsonMarshalX( + modelscapi.GetDecisionsStreamResponse{ + New: modelscapi.GetDecisionsStreamResponseNew{ + &modelscapi.GetDecisionsStreamResponseNewItem{ + Scenario: types.StrPtr("crowdsecurity/test1"), + Scope: types.StrPtr("Ip"), + Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ + { + Value: types.StrPtr("1.2.3.4"), + Duration: types.StrPtr("24h"), + }, + }, + }, + }, + Links: &modelscapi.GetDecisionsStreamResponseLinks{ + Blocklists: []*modelscapi.BlocklistLink{ + { + URL: types.StrPtr("http://api.crowdsec.net/blocklist1"), + Name: types.StrPtr("blocklist1"), + Scope: types.StrPtr("Ip"), + Remediation: types.StrPtr("ban"), + Duration: types.StrPtr("24h"), + }, + }, + }, + }, + ), + )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { + assert.Equal(t, "", req.Header.Get("If-Modified-Since")) + return httpmock.NewStringResponse(304, ""), nil + }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") + require.NoError(t, err) + + apic, err := apiclient.NewDefaultClient( + url, + "/api", + fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), + nil, + ) + require.NoError(t, err) + + api.apiClient = apic + err = api.PullTop() + require.NoError(t, err) +} + func TestAPICPush(t *testing.T) { tests := []struct { name string