diff --git a/cmd/crowdsec-cli/notifications.go b/cmd/crowdsec-cli/notifications.go index 487cc96fd..9d4e8610c 100644 --- a/cmd/crowdsec-cli/notifications.go +++ b/cmd/crowdsec-cli/notifications.go @@ -1,20 +1,28 @@ package main import ( + "context" "encoding/csv" "encoding/json" "fmt" "io/fs" + "net/url" "os" "path/filepath" + "strconv" "strings" + "time" + "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" + "github.com/crowdsecurity/crowdsec/pkg/cwversion" + "github.com/go-openapi/strfmt" "github.com/olekukonko/tablewriter" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "gopkg.in/tomb.v2" ) type NotificationsCfg struct { @@ -135,6 +143,81 @@ func NewNotificationsCmd() *cobra.Command { }, } cmdNotifications.AddCommand(cmdNotificationsInspect) + var cmdNotificationsReinject = &cobra.Command{ + Use: "reinject", + Short: "reinject alerts into notifications system", + Long: `Reinject alerts into notifications system`, + Example: `cscli notifications reinject `, + Args: cobra.ExactArgs(2), + DisableAutoGenTag: true, + Run: func(cmd *cobra.Command, args []string) { + var ( + pluginBroker csplugin.PluginBroker + pluginTomb tomb.Tomb + ) + if len(args) != 2 { + printHelp(cmd) + return + } + id, err := strconv.Atoi(args[0]) + if err != nil { + log.Fatalf("bad alert id %s", args[0]) + } + if err := csConfig.LoadAPIClient(); err != nil { + log.Fatalf("loading api client: %s", err.Error()) + } + if csConfig.API.Client == nil { + log.Fatalln("There is no configuration on 'api_client:'") + } + if csConfig.API.Client.Credentials == nil { + log.Fatalf("Please provide credentials for the API in '%s'", csConfig.API.Client.CredentialsFilePath) + } + apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL) + Client, err = apiclient.NewClient(&apiclient.Config{ + MachineID: csConfig.API.Client.Credentials.Login, + Password: strfmt.Password(csConfig.API.Client.Credentials.Password), + UserAgent: fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), + URL: apiURL, + VersionPrefix: "v1", + }) + + alert, _, err := Client.Alerts.GetByID(context.Background(), id) + if err != nil { + log.Fatalf("can't find alert with id %s: %s", args[0], err) + + } + + err = pluginBroker.Init(csConfig.PluginConfig, csConfig.API.Server.Profiles, csConfig.ConfigPaths) + if err != nil { + log.Fatalf("Can't initialize plugins: %s", err.Error()) + } + + pluginTomb.Go(func() error { + pluginBroker.Run(&pluginTomb) + fmt.Printf("\nreturned\n") + return nil + }) + + loop: + for { + select { + case pluginBroker.PluginChannel <- csplugin.ProfileAlert{ + ProfileID: 1, + Alert: alert, + }: + break loop + default: + time.Sleep(50 * time.Millisecond) + log.Info("sleeping\n") + + } + } + pluginTomb.Kill(errors.New("terminating")) + pluginTomb.Wait() + + }, + } + cmdNotifications.AddCommand(cmdNotificationsReinject) return cmdNotifications } diff --git a/cmd/crowdsec/api.go b/cmd/crowdsec/api.go index 2dbce31b0..7e955f230 100644 --- a/cmd/crowdsec/api.go +++ b/cmd/crowdsec/api.go @@ -34,6 +34,10 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) { return nil, fmt.Errorf("unable to run local API: %s", err) } log.Info("initiated plugin broker") + apiserver := &apiserver.APIServer{ + URL: "", + TLS: &csconfig.TLSCfg{}, + } apiServer.AttachPluginBroker(&pluginBroker) } diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 44437133a..ffddd3933 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -146,7 +146,7 @@ func TestCreateAlertChannels(t *testing.T) { if err != nil { log.Fatalln(err.Error()) } - apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert) + apiServer.Controller.PluginChannel = make(chan csplugin.ProfileAlert) apiServer.InitController() loginResp, err := LoginToTestAPI(apiServer.router, config) @@ -160,7 +160,7 @@ func TestCreateAlertChannels(t *testing.T) { wg.Add(1) go func() { - pd = <-apiServer.controller.PluginChannel + pd = <-apiServer.Controller.PluginChannel wg.Done() }() diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index e84655049..4d9bf84d4 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -30,9 +30,9 @@ var ( type APIServer struct { URL string TLS *csconfig.TLSCfg + Controller *controllers.Controller dbClient *database.Client logFile string - controller *controllers.Controller flushScheduler *gocron.Scheduler router *gin.Engine httpServer *http.Server @@ -227,7 +227,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) { TLS: config.TLS, logFile: logFile, dbClient: dbClient, - controller: controller, + Controller: controller, flushScheduler: flushScheduler, router: router, apic: apiClient, @@ -323,10 +323,10 @@ func (s *APIServer) Shutdown() error { } func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) { - s.controller.PluginChannel = broker.PluginChannel + s.Controller.PluginChannel = broker.PluginChannel } func (s *APIServer) InitController() error { - err := s.controller.Init() + err := s.Controller.Init() return err } diff --git a/pkg/csplugin/broker.go b/pkg/csplugin/broker.go index 6e2536cde..0c752b45b 100644 --- a/pkg/csplugin/broker.go +++ b/pkg/csplugin/broker.go @@ -96,10 +96,11 @@ func (pb *PluginBroker) Kill() { } } -func (pb *PluginBroker) Run(tomb *tomb.Tomb) { +func (pb *PluginBroker) Run(pluginTomb *tomb.Tomb) { //we get signaled via the channel when notifications need to be delivered to plugin (via the watcher) - pb.watcher.Start(tomb) + pb.watcher.Start(&tomb.Tomb{}) for { + fmt.Printf("looping") select { case profileAlert := <-pb.PluginChannel: pb.addProfileAlert(profileAlert) @@ -116,8 +117,25 @@ func (pb *PluginBroker) Run(tomb *tomb.Tomb) { log.WithField("plugin:", pluginName).Error(err) } }() - - case <-tomb.Dying(): + case <-pluginTomb.Dying(): + pb.watcher.tomb.Kill(errors.New("Terminating")) + loop: + for { + select { + case pluginName := <-pb.watcher.PluginEvents: + // this can be ran in goroutine, but then locks will be needed + pluginMutex.Lock() + log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName) + tmpAlerts := pb.alertsByPluginName[pluginName] + pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0) + pluginMutex.Unlock() + if err := pb.pushNotificationsToPlugin(pluginName, tmpAlerts); err != nil { + log.WithField("plugin:", pluginName).Error(err) + } + case <-pb.watcher.tomb.Dead(): + break loop + } + } log.Info("killing all plugins") pb.Kill() return @@ -133,7 +151,10 @@ func (pb *PluginBroker) addProfileAlert(profileAlert ProfileAlert) { pluginMutex.Lock() pb.alertsByPluginName[pluginName] = append(pb.alertsByPluginName[pluginName], profileAlert.Alert) pluginMutex.Unlock() - pb.watcher.Inserts <- pluginName + if _, ok := pb.watcher.PluginConfigByName[pluginName]; ok { + curr, _ := pb.watcher.AlertCountByPluginName.Get(pluginName) + pb.watcher.AlertCountByPluginName.Set(pluginName, curr+1) + } } } func (pb *PluginBroker) profilesContainPlugin(pluginName string) bool { diff --git a/pkg/csplugin/watcher.go b/pkg/csplugin/watcher.go index 955133a4a..f0b185e57 100644 --- a/pkg/csplugin/watcher.go +++ b/pkg/csplugin/watcher.go @@ -73,11 +73,6 @@ func (pw *PluginWatcher) Start(tomb *tomb.Tomb) { return nil }) } - - pw.tomb.Go(func() error { - pw.watchPluginAlertCounts() - return nil - }) } func (pw *PluginWatcher) watchPluginTicker(pluginName string) { @@ -139,21 +134,8 @@ func (pw *PluginWatcher) watchPluginTicker(pluginName string) { } case <-pw.tomb.Dying(): ticker.Stop() - return - } - } -} - -func (pw *PluginWatcher) watchPluginAlertCounts() { - for { - select { - case pluginName := <-pw.Inserts: - //we only "count" pending alerts, and watchPluginTicker is actually going to send it - if _, ok := pw.PluginConfigByName[pluginName]; ok { - curr, _ := pw.AlertCountByPluginName.Get(pluginName) - pw.AlertCountByPluginName.Set(pluginName, curr+1) - } - case <-pw.tomb.Dying(): + pw.PluginEvents <- pluginName + log.Tracef("sending alerts to %s", pluginName) return } }