diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index a199e2892..7e4347c2a 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -618,12 +618,23 @@ func (a *apic) PullTop(forcePull bool) error { } // update blocklists - if err := a.UpdateBlocklists(data.Links, add_counters); err != nil { + if err := a.UpdateBlocklists(data.Links, add_counters, forcePull); err != nil { return fmt.Errorf("while updating blocklists: %w", err) } return nil } +// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert +func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error { + add_counters, _ := makeAddAndDeleteCounters() + if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{ + Blocklists: []*modelscapi.BlocklistLink{blocklist}, + }, add_counters, forcePull); err != nil { + return fmt.Errorf("while pulling blocklist: %w", err) + } + return nil +} + // if decisions is whitelisted: return representation of the whitelist ip or cidr // if not whitelisted: empty string func (a *apic) whitelistedBy(decision *models.Decision) string { @@ -710,7 +721,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo return false, nil } -func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, add_counters map[string]map[string]int) error { +func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, add_counters map[string]map[string]int, forcePull bool) error { if blocklist.Scope == nil { log.Warningf("blocklist has no scope") return nil @@ -719,12 +730,16 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap log.Warningf("blocklist has no duration") return nil } - forcePull, err := a.ShouldForcePullBlocklist(blocklist) - if err != nil { - return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err) + if !forcePull { + _forcePull, err := a.ShouldForcePullBlocklist(blocklist) + if err != nil { + return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err) + } + forcePull = _forcePull } blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name) var lastPullTimestamp *string + var err error if !forcePull { lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) if err != nil { @@ -764,7 +779,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap return nil } -func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int) error { +func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int, forcePull bool) error { if links == nil { return nil } @@ -778,7 +793,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink return fmt.Errorf("while creating default client: %w", err) } for _, blocklist := range links.Blocklists { - if err := a.updateBlocklist(defaultClient, blocklist, add_counters); err != nil { + if err := a.updateBlocklist(defaultClient, blocklist, add_counters, forcePull); err != nil { return err } } diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 97127cad0..736a690c9 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -973,6 +973,37 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { require.NoError(t, err) } +func TestAPICPullBlocklistCall(t *testing.T) { + api := getAPIC(t) + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { + assert.Equal(t, "", req.Header.Get("If-Modified-Since")) + return httpmock.NewStringResponse(200, "1.2.3.4"), nil + }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") + require.NoError(t, err) + + apic, err := apiclient.NewDefaultClient( + url, + "/api", + fmt.Sprintf("crowdsec/%s", version.String()), + nil, + ) + require.NoError(t, err) + + api.apiClient = apic + err = api.PullBlocklist(&modelscapi.BlocklistLink{ + URL: ptr.Of("http://api.crowdsec.net/blocklist1"), + Name: ptr.Of("blocklist1"), + Scope: ptr.Of("Ip"), + Remediation: ptr.Of("ban"), + Duration: ptr.Of("24h"), + }, true) + require.NoError(t, err) +} + func TestAPICPush(t *testing.T) { tests := []struct { name string diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index b57e7655d..6ab8f3734 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -11,6 +11,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/modelscapi" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -19,6 +20,23 @@ type deleteDecisions struct { Decisions []string `json:"decisions"` } +type blocklistLink struct { + // blocklist name + Name string `json:"name"` + // blocklist url + Url string `json:"url"` + // blocklist remediation + Remediation string `json:"remediation"` + // blocklist scope + Scope string `json:"scope,omitempty"` + // blocklist duration + Duration string `json:"duration,omitempty"` +} + +type forcePull struct { + Blocklist *blocklistLink `json:"blocklist,omitempty"` +} + func DecisionCmd(message *Message, p *Papi, sync bool) error { switch message.Header.OperationCmd { case "delete": @@ -144,11 +162,35 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error { log.Infof("Received reauth command from PAPI, resetting token") p.apiClient.GetClient().Transport.(*apiclient.JWTTransport).ResetToken() case "force_pull": - log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") - err := p.apic.PullTop(true) + data, err := json.Marshal(message.Data) if err != nil { - return fmt.Errorf("failed to force pull operation: %s", err) + return err } + forcePullMsg := forcePull{} + if err := json.Unmarshal(data, &forcePullMsg); err != nil { + return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err) + } + + if forcePullMsg.Blocklist == nil { + log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") + err = p.apic.PullTop(true) + if err != nil { + return fmt.Errorf("failed to force pull operation: %s", err) + } + } else { + log.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name) + err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{ + Name: &forcePullMsg.Blocklist.Name, + URL: &forcePullMsg.Blocklist.Url, + Remediation: &forcePullMsg.Blocklist.Remediation, + Scope: &forcePullMsg.Blocklist.Scope, + Duration: &forcePullMsg.Blocklist.Duration, + }, true) + if err != nil { + return fmt.Errorf("failed to force pull operation: %s", err) + } + } + default: return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType) }