diff --git a/cmd/crowdsec-cli/capi.go b/cmd/crowdsec-cli/capi.go index 90ae4697e..0cecad0fa 100644 --- a/cmd/crowdsec-cli/capi.go +++ b/cmd/crowdsec-cli/capi.go @@ -20,7 +20,7 @@ import ( "gopkg.in/yaml.v2" ) -const CAPIBaseURL string = "https://api.dev.crowdsec.net/" +const CAPIBaseURL string = "https://api.crowdsec.net/" const CAPIURLPrefix = "v3" func NewCapiCmd() *cobra.Command { diff --git a/cmd/crowdsec-cli/papi.go b/cmd/crowdsec-cli/papi.go index 9951d5db2..176b81370 100644 --- a/cmd/crowdsec-cli/papi.go +++ b/cmd/crowdsec-cli/papi.go @@ -116,7 +116,7 @@ func NewPapiSyncCmd() *cobra.Command { } t.Go(papi.SyncDecisions) - err = papi.PullOnce(time.Time{}) + err = papi.PullOnce(time.Time{}, true) if err != nil { log.Fatalf("unable to sync decisions: %s", err) diff --git a/cmd/crowdsec/serve.go b/cmd/crowdsec/serve.go index f3c053a63..4c9ff55a3 100644 --- a/cmd/crowdsec/serve.go +++ b/cmd/crowdsec/serve.go @@ -18,7 +18,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -//nolint: deadcode,unused // debugHandler is kept as a dev convenience: it shuts down and serialize internal state +//nolint:deadcode,unused // debugHandler is kept as a dev convenience: it shuts down and serialize internal state func debugHandler(sig os.Signal, cConfig *csconfig.Config) error { var ( tmpFile string @@ -356,7 +356,6 @@ func Serve(cConfig *csconfig.Config, apiReady chan bool, agentReady chan bool) e if !sent || err != nil { log.Errorf("Failed to notify(sent: %v): %v", sent, err) } - // wait for signals return HandleSignals(cConfig) } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index e811f6004..5e8c714a1 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -61,6 +61,7 @@ type apic struct { credentials *csconfig.ApiCredentialsCfg scenarioList []string consoleConfig *csconfig.ConsoleConfig + isPulling chan bool whitelists *csconfig.CapiWhitelist } @@ -171,6 +172,7 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con pushIntervalFirst: randomDuration(pushIntervalDefault, pushIntervalDelta), metricsInterval: metricsIntervalDefault, metricsIntervalFirst: randomDuration(metricsIntervalDefault, metricsIntervalDelta), + isPulling: make(chan bool, 1), whitelists: apicWhitelist, } @@ -537,13 +539,26 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio // we receive a list of decisions and links for blocklist and we need to create a list of alerts : // one alert for "community blocklist" // one alert per list we're subscribed to -func (a *apic) PullTop() error { +func (a *apic) PullTop(forcePull bool) error { var err error - if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil { - return err - } else if !lastPullIsOld { - return nil + //A mutex with TryLock would be a bit simpler + //But go does not guarantee that TryLock will be able to acquire the lock even if it is available + select { + case a.isPulling <- true: + defer func() { + <-a.isPulling + }() + default: + return errors.New("pull already in progress") + } + + if !forcePull { + if lastPullIsOld, err := a.CAPIPullIsOld(); err != nil { + return err + } else if !lastPullIsOld { + return nil + } } log.Infof("Starting community-blocklist update") @@ -780,7 +795,7 @@ func (a *apic) Pull() error { } time.Sleep(1 * time.Second) } - if err := a.PullTop(); err != nil { + if err := a.PullTop(false); err != nil { log.Errorf("capi pull top: %s", err) } @@ -791,7 +806,7 @@ func (a *apic) Pull() error { select { case <-ticker.C: ticker.Reset(a.pullInterval) - if err := a.PullTop(); err != nil { + if err := a.PullTop(false); err != nil { log.Errorf("capi pull top: %s", err) continue } diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index df3d2e965..965ef0378 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -64,6 +64,7 @@ func getAPIC(t *testing.T) *apic { ShareCustomScenarios: types.BoolPtr(false), ShareContext: types.BoolPtr(false), }, + isPulling: make(chan bool, 1), } } @@ -666,7 +667,7 @@ func TestAPICWhitelists(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop() + err = api.PullTop(false) require.NoError(t, err) assertTotalDecisionCount(t, api.dbClient, 5) //2 from FIRE + 2 from bl + 1 existing @@ -797,7 +798,7 @@ func TestAPICPullTop(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop() + err = api.PullTop(false) require.NoError(t, err) assertTotalDecisionCount(t, api.dbClient, 5) @@ -879,7 +880,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop() + err = api.PullTop(false) require.NoError(t, err) blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *types.StrPtr("blocklist1")) @@ -892,7 +893,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { assert.NotEqual(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(304, ""), nil }) - err = api.PullTop() + err = api.PullTop(false) require.NoError(t, err) secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) require.NoError(t, err) @@ -966,7 +967,7 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { require.NoError(t, err) api.apiClient = apic - err = api.PullTop() + err = api.PullTop(false) require.NoError(t, err) } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 73758eae7..278a776be 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -420,6 +420,9 @@ func (s *APIServer) Close() { if s.apic != nil { s.apic.Shutdown() // stop apic first since it use dbClient } + if s.papi != nil { + s.papi.Shutdown() // papi also uses the dbClient + } s.dbClient.Ent.Close() if s.flushScheduler != nil { s.flushScheduler.Stop() diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 581439240..33028b13c 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -15,7 +15,6 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/pkg/errors" - "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "gopkg.in/tomb.v2" ) @@ -29,9 +28,10 @@ const ( ) var ( - operationMap = map[string]func(*Message, *Papi) error{ - "decision": DecisionCmd, - "alert": AlertCmd, + operationMap = map[string]func(*Message, *Papi, bool) error{ + "decision": DecisionCmd, + "alert": AlertCmd, + "management": ManagementCmd, } ) @@ -71,6 +71,7 @@ type Papi struct { SyncInterval time.Duration consoleConfig *csconfig.ConsoleConfig Logger *log.Entry + apic *apic } type PapiPermCheckError struct { @@ -85,7 +86,7 @@ type PapiPermCheckSuccess struct { func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.ConsoleConfig, logLevel log.Level) (*Papi, error) { - logger := logrus.New() + logger := log.New() if err := types.ConfigureLogger(logger); err != nil { return &Papi{}, fmt.Errorf("creating papi logger: %s", err) } @@ -118,6 +119,7 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons pullTomb: tomb.Tomb{}, syncTomb: tomb.Tomb{}, apiClient: apic.apiClient, + apic: apic, consoleConfig: consoleConfig, Logger: logger.WithFields(log.Fields{"interval": SyncInterval.Seconds(), "source": "papi"}), } @@ -125,7 +127,7 @@ func NewPAPI(apic *apic, dbClient *database.Client, consoleConfig *csconfig.Cons return papi, nil } -func (p *Papi) handleEvent(event longpollclient.Event) error { +func (p *Papi) handleEvent(event longpollclient.Event, sync bool) error { logger := p.Logger.WithField("request-id", event.RequestId) logger.Debugf("message received: %+v", event.Data) message := &Message{} @@ -141,7 +143,7 @@ func (p *Papi) handleEvent(event longpollclient.Event) error { if operationFunc, ok := operationMap[message.Header.OperationType]; ok { logger.Debugf("Calling operation '%s'", message.Header.OperationType) - err := operationFunc(message, p) + err := operationFunc(message, p, sync) if err != nil { return fmt.Errorf("'%s %s failed: %s", message.Header.OperationType, message.Header.OperationCmd, err) } @@ -192,7 +194,7 @@ func reverse(s []longpollclient.Event) []longpollclient.Event { return a } -func (p *Papi) PullOnce(since time.Time) error { +func (p *Papi) PullOnce(since time.Time, sync bool) error { events, err := p.Client.PullOnce(since) if err != nil { return err @@ -202,7 +204,7 @@ func (p *Papi) PullOnce(since time.Time) error { eventsCount := len(events) p.Logger.Infof("received %d events", eventsCount) for i, event := range reversedEvents { - if err := p.handleEvent(event); err != nil { + if err := p.handleEvent(event, sync); err != nil { p.Logger.WithField("request-id", event.RequestId).Errorf("failed to handle event: %s", err) } p.Logger.Debugf("handled event %d/%d", i, eventsCount) @@ -251,7 +253,7 @@ func (p *Papi) Pull() error { return errors.Wrap(err, "failed to marshal last timestamp") } - err = p.handleEvent(event) + err = p.handleEvent(event, false) if err != nil { logger.Errorf("failed to handle event: %s", err) continue diff --git a/pkg/apiserver/papi_cmd.go b/pkg/apiserver/papi_cmd.go index 002d103fc..3635bc207 100644 --- a/pkg/apiserver/papi_cmd.go +++ b/pkg/apiserver/papi_cmd.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/pkg/errors" @@ -16,7 +17,7 @@ type deleteDecisions struct { Decisions []string `json:"decisions"` } -func DecisionCmd(message *Message, p *Papi) error { +func DecisionCmd(message *Message, p *Papi, sync bool) error { switch message.Header.OperationCmd { case "delete": @@ -64,7 +65,7 @@ func DecisionCmd(message *Message, p *Papi) error { return nil } -func AlertCmd(message *Message, p *Papi) error { +func AlertCmd(message *Message, p *Papi, sync bool) error { switch message.Header.OperationCmd { case "add": data, err := json.Marshal(message.Data) @@ -130,3 +131,24 @@ func AlertCmd(message *Message, p *Papi) error { return nil } + +func ManagementCmd(message *Message, p *Papi, sync bool) error { + if sync { + log.Infof("Ignoring management command from PAPI in sync mode") + return nil + } + switch message.Header.OperationCmd { + case "reauth": + log.Infof("Received reauth command from PAPI, resetting token") + p.apiClient.GetClient().Transport.(*apiclient.JWTTransport).ResetToken() + case "force_pull": + log.Infof("Received force_pull command from PAPI, pulling community and 3rd-party blocklists") + err := p.apic.PullTop(true) + if err != nil { + return fmt.Errorf("failed to force pull operation: %s", err) + } + default: + return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType) + } + return nil +}