optimize blocklist fetch (#2039)

This commit is contained in:
Cristian Nitescu 2023-02-13 15:06:14 +01:00 committed by GitHub
parent f280505eaa
commit ecb32d74c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 280 additions and 29 deletions

View file

@ -158,7 +158,7 @@ func newResponse(r *http.Response) *Response {
}
func CheckResponse(r *http.Response) error {
if c := r.StatusCode; 200 <= c && c <= 299 {
if c := r.StatusCode; 200 <= c && c <= 299 || c == 304 {
return nil
}
errorResponse := &ErrorResponse{}

View file

@ -4,7 +4,6 @@ import (
"bufio"
"context"
"fmt"
"io"
"net/http"
"github.com/crowdsecurity/crowdsec/pkg/models"
@ -150,29 +149,58 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
return &v2Decisions, resp, nil
}
func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink) ([]*models.Decision, error) {
func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp *string) ([]*models.Decision, bool, error) {
if blocklist.URL == nil {
return nil, errors.New("blocklist URL is nil")
return nil, false, errors.New("blocklist URL is nil")
}
log.Debugf("Fetching blocklist %s", *blocklist.URL)
req, err := s.client.NewRequest(http.MethodGet, *blocklist.URL, nil)
client := http.Client{}
req, err := http.NewRequest(http.MethodGet, *blocklist.URL, nil)
if err != nil {
return nil, err
return nil, false, err
}
pr, pw := io.Pipe()
defer pr.Close()
go func() {
defer pw.Close()
_, err = s.client.Do(ctx, req, pw)
if err != nil {
log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err)
if lastPullTimestamp != nil {
req.Header.Set("If-Modified-Since", *lastPullTimestamp)
}
req = req.WithContext(ctx)
log.Debugf("[URL] %s %s", req.Method, req.URL)
// we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc
resp, err := client.Do(req)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil {
// If we got an error, and the context has been canceled,
// the context's error is probably more useful.
select {
case <-ctx.Done():
return nil, false, ctx.Err()
default:
}
}()
// If the error type is *url.Error, sanitize its URL before returning.
log.Errorf("Error fetching blocklist %s: %s", *blocklist.URL, err)
return nil, false, err
}
if resp.StatusCode == http.StatusNotModified {
if lastPullTimestamp != nil {
log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, *lastPullTimestamp)
} else {
log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL)
}
return nil, false, nil
}
if resp.StatusCode != http.StatusOK {
log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL)
return nil, false, nil
}
decisions := make([]*models.Decision, 0)
scanner := bufio.NewScanner(pr)
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
decision := scanner.Text()
decisions = append(decisions, &models.Decision{
@ -185,7 +213,9 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
})
}
return decisions, nil
// here the upper go routine is finished because scanner.Scan() is blocking until pw.Close() is called
// so it's safe to use the isModified variable here
return decisions, true, nil
}
func (s *DecisionsService) GetStream(ctx context.Context, opts DecisionsStreamOpts) (*models.DecisionsStreamResponse, *Response, error) {

View file

@ -11,6 +11,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/modelscapi"
"github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -377,9 +378,11 @@ func TestDecisionsFromBlocklist(t *testing.T) {
defer teardown()
mux.HandleFunc("/blocklist", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.Header.Get("X-Api-Key"), "ixu")
testMethod(t, r, http.MethodGet)
if r.Header.Get("If-Modified-Since") == "Sun, 01 Jan 2023 01:01:01 GMT" {
w.WriteHeader(http.StatusNotModified)
return
}
if r.Method == http.MethodGet {
w.WriteHeader(http.StatusOK)
w.Write([]byte("1.2.3.4\r\n1.2.3.5"))
@ -407,7 +410,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
tnameBlocklist := "blocklist1"
tremediationBlocklist := "ban"
tscopeBlocklist := "ip"
turlBlocklist := "/v3/blocklist"
turlBlocklist := urlx + "/v3/blocklist"
torigin := "lists"
expected := []*models.Decision{
{
@ -427,14 +430,15 @@ func TestDecisionsFromBlocklist(t *testing.T) {
Origin: &torigin,
},
}
decisions, err := newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
decisions, isModified, err := newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
URL: &turlBlocklist,
Scope: &tscopeBlocklist,
Remediation: &tremediationBlocklist,
Name: &tnameBlocklist,
Duration: &tdurationBlocklist,
})
}, nil)
require.NoError(t, err)
assert.True(t, isModified)
log.Infof("decision1: %+v", decisions[0])
log.Infof("expected1: %+v", expected[0])
@ -448,6 +452,26 @@ func TestDecisionsFromBlocklist(t *testing.T) {
if !reflect.DeepEqual(decisions, expected) {
t.Fatalf("returned %+v, want %+v", decisions, expected)
}
// test cache control
_, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
URL: &turlBlocklist,
Scope: &tscopeBlocklist,
Remediation: &tremediationBlocklist,
Name: &tnameBlocklist,
Duration: &tdurationBlocklist,
}, types.StrPtr("Sun, 01 Jan 2023 01:01:01 GMT"))
require.NoError(t, err)
assert.False(t, isModified)
_, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
URL: &turlBlocklist,
Scope: &tscopeBlocklist,
Remediation: &tremediationBlocklist,
Name: &tnameBlocklist,
Duration: &tdurationBlocklist,
}, types.StrPtr("Mon, 02 Jan 2023 01:01:01 GMT"))
require.NoError(t, err)
assert.True(t, isModified)
}
func TestDeleteDecisions(t *testing.T) {

View file

@ -20,6 +20,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
"github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
"github.com/crowdsecurity/crowdsec/pkg/models"
@ -598,6 +599,36 @@ func (a *apic) SaveAlerts(alertsFromCapi []*models.Alert, add_counters map[strin
return nil
}
func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bool, error) {
// we should force pull if the blocklist decisions are about to expire or there's no decision in the db
alertQuery := a.dbClient.Ent.Alert.Query()
alertQuery.Where(alert.SourceScopeEQ(fmt.Sprintf("%s:%s", types.ListOrigin, *blocklist.Name)))
alertQuery.Order(ent.Desc(alert.FieldCreatedAt))
alertInstance, err := alertQuery.First(context.Background())
if err != nil {
if ent.IsNotFound(err) {
log.Debugf("no alert found for %s, force refresh", *blocklist.Name)
return true, nil
}
return false, errors.Wrap(err, "while getting alert")
}
decisionQuery := a.dbClient.Ent.Decision.Query()
decisionQuery.Where(decision.HasOwnerWith(alert.IDEQ(alertInstance.ID)))
firstDecision, err := decisionQuery.First(context.Background())
if err != nil {
if ent.IsNotFound(err) {
log.Debugf("no decision found for %s, force refresh", *blocklist.Name)
return true, nil
}
return false, errors.Wrap(err, "while getting decision")
}
if firstDecision == nil || firstDecision.Until == nil || firstDecision.Until.Sub(time.Now().UTC()) < (a.pullInterval+15*time.Minute) {
log.Debugf("at least one decision found for %s, expire soon, force refresh", *blocklist.Name)
return true, nil
}
return false, nil
}
func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int) error {
if links == nil {
return nil
@ -607,7 +638,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
}
// we must use a different http client than apiClient's because the transport of apiClient is jwtTransport or here we have signed apis that are incompatibles
// we can use the same baseUrl as the urls are absolute and the parse will take care of it
defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", &http.Client{})
defaultClient, err := apiclient.NewDefaultClient(a.apiClient.BaseURL, "", "", nil)
if err != nil {
return errors.Wrap(err, "while creating default client")
}
@ -620,10 +651,34 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
log.Warningf("blocklist has no duration")
continue
}
decisions, err := defaultClient.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist)
forcePull, err := a.ShouldForcePullBlocklist(blocklist)
if err != nil {
return errors.Wrapf(err, "while checking if we should force pull blocklist %s", *blocklist.Name)
}
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
var lastPullTimestamp *string
if !forcePull {
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
if err != nil {
return errors.Wrapf(err, "while getting last pull timestamp for blocklist %s", *blocklist.Name)
}
}
decisions, has_changed, err := defaultClient.Decisions.GetDecisionsFromBlocklist(context.Background(), blocklist, lastPullTimestamp)
if err != nil {
return errors.Wrapf(err, "while getting decisions from blocklist %s", *blocklist.Name)
}
if !has_changed {
if lastPullTimestamp == nil {
log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
} else {
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
}
continue
}
err = a.dbClient.SetConfigItem(blocklistConfigItemName, time.Now().UTC().Format(http.TimeFormat))
if err != nil {
return errors.Wrapf(err, "while setting last pull timestamp for blocklist %s", *blocklist.Name)
}
if len(decisions) == 0 {
log.Infof("blocklist %s has no decisions", *blocklist.Name)
continue

View file

@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"reflect"
@ -568,14 +569,14 @@ func TestAPICPullTop(t *testing.T) {
Blocklists: []*modelscapi.BlocklistLink{
{
URL: types.StrPtr("http://api.crowdsec.net/blocklist1"),
Name: types.StrPtr("crowdsecurity/http-bf"),
Name: types.StrPtr("blocklist1"),
Scope: types.StrPtr("Ip"),
Remediation: types.StrPtr("ban"),
Duration: types.StrPtr("24h"),
},
{
URL: types.StrPtr("http://api.crowdsec.net/blocklist2"),
Name: types.StrPtr("crowdsecurity/ssh-bf"),
Name: types.StrPtr("blocklist2"),
Scope: types.StrPtr("Ip"),
Remediation: types.StrPtr("ban"),
Duration: types.StrPtr("24h"),
@ -622,19 +623,160 @@ func TestAPICPullTop(t *testing.T) {
}
assert.Equal(t, 3, len(alertScenario))
assert.Equal(t, 1, alertScenario[SCOPE_CAPI_ALIAS_ALIAS])
assert.Equal(t, 1, alertScenario["lists:crowdsecurity/ssh-bf"])
assert.Equal(t, 1, alertScenario["lists:crowdsecurity/http-bf"])
assert.Equal(t, 1, alertScenario["lists:blocklist1"])
assert.Equal(t, 1, alertScenario["lists:blocklist2"])
for _, decisions := range validDecisions {
decisionScenarioFreq[decisions.Scenario]++
}
assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/http-bf"], 1)
assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1)
assert.Equal(t, 1, decisionScenarioFreq["blocklist1"], 1)
assert.Equal(t, 1, decisionScenarioFreq["blocklist2"], 1)
assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test1"], 1)
assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test2"], 1)
}
func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
// no decision in db, no last modified parameter.
api := getAPIC(t)
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
200, jsonMarshalX(
modelscapi.GetDecisionsStreamResponse{
New: modelscapi.GetDecisionsStreamResponseNew{
&modelscapi.GetDecisionsStreamResponseNewItem{
Scenario: types.StrPtr("crowdsecurity/test1"),
Scope: types.StrPtr("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{
Value: types.StrPtr("1.2.3.4"),
Duration: types.StrPtr("24h"),
},
},
},
},
Links: &modelscapi.GetDecisionsStreamResponseLinks{
Blocklists: []*modelscapi.BlocklistLink{
{
URL: types.StrPtr("http://api.crowdsec.net/blocklist1"),
Name: types.StrPtr("blocklist1"),
Scope: types.StrPtr("Ip"),
Remediation: types.StrPtr("ban"),
Duration: types.StrPtr("24h"),
},
},
},
},
),
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(200, "1.2.3.4"), nil
})
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
apic, err := apiclient.NewDefaultClient(
url,
"/api",
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
nil,
)
require.NoError(t, err)
api.apiClient = apic
err = api.PullTop()
require.NoError(t, err)
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *types.StrPtr("blocklist1"))
lastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
require.NoError(t, err)
assert.NotEqual(t, "", *lastPullTimestamp)
// new call should return 304 and should not change lastPullTimestamp
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
assert.NotEqual(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(304, ""), nil
})
err = api.PullTop()
require.NoError(t, err)
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
require.NoError(t, err)
assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp)
}
func TestAPICPullTopBLCacheForceCall(t *testing.T) {
api := getAPIC(t)
httpmock.Activate()
defer httpmock.DeactivateAndReset()
// create a decision about to expire. It should force fetch
alertInstance := api.dbClient.Ent.Alert.
Create().
SetScenario("update list").
SetSourceScope("list:blocklist1").
SetSourceValue("list:blocklist1").
SaveX(context.Background())
api.dbClient.Ent.Decision.Create().
SetOrigin(types.ListOrigin).
SetType("ban").
SetValue("9.9.9.9").
SetScope("Ip").
SetScenario("blocklist1").
SetUntil(time.Now().Add(time.Hour)).
SetOwnerID(alertInstance.ID).
ExecX(context.Background())
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
200, jsonMarshalX(
modelscapi.GetDecisionsStreamResponse{
New: modelscapi.GetDecisionsStreamResponseNew{
&modelscapi.GetDecisionsStreamResponseNewItem{
Scenario: types.StrPtr("crowdsecurity/test1"),
Scope: types.StrPtr("Ip"),
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{
Value: types.StrPtr("1.2.3.4"),
Duration: types.StrPtr("24h"),
},
},
},
},
Links: &modelscapi.GetDecisionsStreamResponseLinks{
Blocklists: []*modelscapi.BlocklistLink{
{
URL: types.StrPtr("http://api.crowdsec.net/blocklist1"),
Name: types.StrPtr("blocklist1"),
Scope: types.StrPtr("Ip"),
Remediation: types.StrPtr("ban"),
Duration: types.StrPtr("24h"),
},
},
},
},
),
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(304, ""), nil
})
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
apic, err := apiclient.NewDefaultClient(
url,
"/api",
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
nil,
)
require.NoError(t, err)
api.apiClient = apic
err = api.PullTop()
require.NoError(t, err)
}
func TestAPICPush(t *testing.T) {
tests := []struct {
name string