manage force_pull message for one blocklist (#2615)

* manage force_pull message for one blocklist

* fix info message on force pull blocklist
This commit is contained in:
Cristian Nitescu 2023-11-29 11:37:46 +01:00 committed by GitHub
parent 6b0bdc5eeb
commit 7c5cbef51a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 98 additions and 10 deletions

View file

@ -618,12 +618,23 @@ func (a *apic) PullTop(forcePull bool) error {
} }
// update blocklists // update blocklists
if err := a.UpdateBlocklists(data.Links, add_counters); err != nil { if err := a.UpdateBlocklists(data.Links, add_counters, forcePull); err != nil {
return fmt.Errorf("while updating blocklists: %w", err) return fmt.Errorf("while updating blocklists: %w", err)
} }
return nil return nil
} }
// we receive a link to a blocklist, we pull the content of the blocklist and we create one alert
func (a *apic) PullBlocklist(blocklist *modelscapi.BlocklistLink, forcePull bool) error {
add_counters, _ := makeAddAndDeleteCounters()
if err := a.UpdateBlocklists(&modelscapi.GetDecisionsStreamResponseLinks{
Blocklists: []*modelscapi.BlocklistLink{blocklist},
}, add_counters, forcePull); err != nil {
return fmt.Errorf("while pulling blocklist: %w", err)
}
return nil
}
// if decisions is whitelisted: return representation of the whitelist ip or cidr // if decisions is whitelisted: return representation of the whitelist ip or cidr
// if not whitelisted: empty string // if not whitelisted: empty string
func (a *apic) whitelistedBy(decision *models.Decision) string { func (a *apic) whitelistedBy(decision *models.Decision) string {
@ -710,7 +721,7 @@ func (a *apic) ShouldForcePullBlocklist(blocklist *modelscapi.BlocklistLink) (bo
return false, nil return false, nil
} }
func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, add_counters map[string]map[string]int) error { func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscapi.BlocklistLink, add_counters map[string]map[string]int, forcePull bool) error {
if blocklist.Scope == nil { if blocklist.Scope == nil {
log.Warningf("blocklist has no scope") log.Warningf("blocklist has no scope")
return nil return nil
@ -719,12 +730,16 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
log.Warningf("blocklist has no duration") log.Warningf("blocklist has no duration")
return nil return nil
} }
forcePull, err := a.ShouldForcePullBlocklist(blocklist) if !forcePull {
if err != nil { _forcePull, err := a.ShouldForcePullBlocklist(blocklist)
return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err) if err != nil {
return fmt.Errorf("while checking if we should force pull blocklist %s: %w", *blocklist.Name, err)
}
forcePull = _forcePull
} }
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name) blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
var lastPullTimestamp *string var lastPullTimestamp *string
var err error
if !forcePull { if !forcePull {
lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName) lastPullTimestamp, err = a.dbClient.GetConfigItem(blocklistConfigItemName)
if err != nil { if err != nil {
@ -764,7 +779,7 @@ func (a *apic) updateBlocklist(client *apiclient.ApiClient, blocklist *modelscap
return nil return nil
} }
func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int) error { func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLinks, add_counters map[string]map[string]int, forcePull bool) error {
if links == nil { if links == nil {
return nil return nil
} }
@ -778,7 +793,7 @@ func (a *apic) UpdateBlocklists(links *modelscapi.GetDecisionsStreamResponseLink
return fmt.Errorf("while creating default client: %w", err) return fmt.Errorf("while creating default client: %w", err)
} }
for _, blocklist := range links.Blocklists { for _, blocklist := range links.Blocklists {
if err := a.updateBlocklist(defaultClient, blocklist, add_counters); err != nil { if err := a.updateBlocklist(defaultClient, blocklist, add_counters, forcePull); err != nil {
return err return err
} }
} }

View file

@ -973,6 +973,37 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestAPICPullBlocklistCall(t *testing.T) {
api := getAPIC(t)
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(200, "1.2.3.4"), nil
})
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
apic, err := apiclient.NewDefaultClient(
url,
"/api",
fmt.Sprintf("crowdsec/%s", version.String()),
nil,
)
require.NoError(t, err)
api.apiClient = apic
err = api.PullBlocklist(&modelscapi.BlocklistLink{
URL: ptr.Of("http://api.crowdsec.net/blocklist1"),
Name: ptr.Of("blocklist1"),
Scope: ptr.Of("Ip"),
Remediation: ptr.Of("ban"),
Duration: ptr.Of("24h"),
}, true)
require.NoError(t, err)
}
func TestAPICPush(t *testing.T) { func TestAPICPush(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View file

@ -11,6 +11,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/modelscapi"
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
) )
@ -19,6 +20,23 @@ type deleteDecisions struct {
Decisions []string `json:"decisions"` Decisions []string `json:"decisions"`
} }
type blocklistLink struct {
// blocklist name
Name string `json:"name"`
// blocklist url
Url string `json:"url"`
// blocklist remediation
Remediation string `json:"remediation"`
// blocklist scope
Scope string `json:"scope,omitempty"`
// blocklist duration
Duration string `json:"duration,omitempty"`
}
type forcePull struct {
Blocklist *blocklistLink `json:"blocklist,omitempty"`
}
func DecisionCmd(message *Message, p *Papi, sync bool) error { func DecisionCmd(message *Message, p *Papi, sync bool) error {
switch message.Header.OperationCmd { switch message.Header.OperationCmd {
case "delete": case "delete":
@ -144,11 +162,35 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
log.Infof("Received reauth command from PAPI, resetting token") log.Infof("Received reauth command from PAPI, resetting token")
p.apiClient.GetClient().Transport.(*apiclient.JWTTransport).ResetToken() p.apiClient.GetClient().Transport.(*apiclient.JWTTransport).ResetToken()
case "force_pull": case "force_pull":
log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") data, err := json.Marshal(message.Data)
err := p.apic.PullTop(true)
if err != nil { if err != nil {
return fmt.Errorf("failed to force pull operation: %s", err) return err
} }
forcePullMsg := forcePull{}
if err := json.Unmarshal(data, &forcePullMsg); err != nil {
return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err)
}
if forcePullMsg.Blocklist == nil {
log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists")
err = p.apic.PullTop(true)
if err != nil {
return fmt.Errorf("failed to force pull operation: %s", err)
}
} else {
log.Infof("Received force_pull command from PAPI, pulling blocklist %s", forcePullMsg.Blocklist.Name)
err = p.apic.PullBlocklist(&modelscapi.BlocklistLink{
Name: &forcePullMsg.Blocklist.Name,
URL: &forcePullMsg.Blocklist.Url,
Remediation: &forcePullMsg.Blocklist.Remediation,
Scope: &forcePullMsg.Blocklist.Scope,
Duration: &forcePullMsg.Blocklist.Duration,
}, true)
if err != nil {
return fmt.Errorf("failed to force pull operation: %s", err)
}
}
default: default:
return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType) return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType)
} }