* add a way to test notifications

* fix an antipattern in broker system with tomb
* kill an unneeded goroutine
This commit is contained in:
sabban 2022-05-20 09:50:51 +02:00
parent a2d91119d4
commit ae2767c8a2
6 changed files with 120 additions and 31 deletions

View file

@ -1,20 +1,28 @@
package main
import (
"context"
"encoding/csv"
"encoding/json"
"fmt"
"io/fs"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
"github.com/go-openapi/strfmt"
"github.com/olekukonko/tablewriter"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"gopkg.in/tomb.v2"
)
type NotificationsCfg struct {
@ -135,6 +143,81 @@ func NewNotificationsCmd() *cobra.Command {
},
}
cmdNotifications.AddCommand(cmdNotificationsInspect)
var cmdNotificationsReinject = &cobra.Command{
Use: "reinject",
Short: "reinject alerts into notifications system",
Long: `Reinject alerts into notifications system`,
Example: `cscli notifications reinject <alert_id> <plugin_name>`,
Args: cobra.ExactArgs(2),
DisableAutoGenTag: true,
Run: func(cmd *cobra.Command, args []string) {
var (
pluginBroker csplugin.PluginBroker
pluginTomb tomb.Tomb
)
if len(args) != 2 {
printHelp(cmd)
return
}
id, err := strconv.Atoi(args[0])
if err != nil {
log.Fatalf("bad alert id %s", args[0])
}
if err := csConfig.LoadAPIClient(); err != nil {
log.Fatalf("loading api client: %s", err.Error())
}
if csConfig.API.Client == nil {
log.Fatalln("There is no configuration on 'api_client:'")
}
if csConfig.API.Client.Credentials == nil {
log.Fatalf("Please provide credentials for the API in '%s'", csConfig.API.Client.CredentialsFilePath)
}
apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL)
Client, err = apiclient.NewClient(&apiclient.Config{
MachineID: csConfig.API.Client.Credentials.Login,
Password: strfmt.Password(csConfig.API.Client.Credentials.Password),
UserAgent: fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
URL: apiURL,
VersionPrefix: "v1",
})
alert, _, err := Client.Alerts.GetByID(context.Background(), id)
if err != nil {
log.Fatalf("can't find alert with id %s: %s", args[0], err)
}
err = pluginBroker.Init(csConfig.PluginConfig, csConfig.API.Server.Profiles, csConfig.ConfigPaths)
if err != nil {
log.Fatalf("Can't initialize plugins: %s", err.Error())
}
pluginTomb.Go(func() error {
pluginBroker.Run(&pluginTomb)
fmt.Printf("\nreturned\n")
return nil
})
loop:
for {
select {
case pluginBroker.PluginChannel <- csplugin.ProfileAlert{
ProfileID: 1,
Alert: alert,
}:
break loop
default:
time.Sleep(50 * time.Millisecond)
log.Info("sleeping\n")
}
}
pluginTomb.Kill(errors.New("terminating"))
pluginTomb.Wait()
},
}
cmdNotifications.AddCommand(cmdNotificationsReinject)
return cmdNotifications
}

View file

@ -33,6 +33,10 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) {
return nil, errors.Wrap(err, "unable to run local API")
}
log.Info("initiated plugin broker")
apiserver := &apiserver.APIServer{
URL: "",
TLS: &csconfig.TLSCfg{},
}
apiServer.AttachPluginBroker(&pluginBroker)
}

View file

@ -152,7 +152,7 @@ func TestCreateAlertChannels(t *testing.T) {
if err != nil {
log.Fatalln(err.Error())
}
apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert)
apiServer.Controller.PluginChannel = make(chan csplugin.ProfileAlert)
apiServer.InitController()
loginResp, err := LoginToTestAPI(apiServer.router, config)
@ -166,7 +166,7 @@ func TestCreateAlertChannels(t *testing.T) {
wg.Add(1)
go func() {
pd = <-apiServer.controller.PluginChannel
pd = <-apiServer.Controller.PluginChannel
wg.Done()
}()

View file

@ -34,9 +34,9 @@ var (
type APIServer struct {
URL string
TLS *csconfig.TLSCfg
Controller *controllers.Controller
dbClient *database.Client
logFile string
controller *controllers.Controller
flushScheduler *gocron.Scheduler
router *gin.Engine
httpServer *http.Server
@ -230,7 +230,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
TLS: config.TLS,
logFile: logFile,
dbClient: dbClient,
controller: controller,
Controller: controller,
flushScheduler: flushScheduler,
router: router,
apic: apiClient,
@ -370,11 +370,10 @@ func (s *APIServer) Shutdown() error {
}
func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) {
s.controller.PluginChannel = broker.PluginChannel
s.Controller.PluginChannel = broker.PluginChannel
}
func (s *APIServer) InitController() error {
err := s.controller.Init()
if err != nil {
return errors.Wrap(err, "controller init")

View file

@ -96,10 +96,11 @@ func (pb *PluginBroker) Kill() {
}
}
func (pb *PluginBroker) Run(tomb *tomb.Tomb) {
func (pb *PluginBroker) Run(pluginTomb *tomb.Tomb) {
//we get signaled via the channel when notifications need to be delivered to plugin (via the watcher)
pb.watcher.Start(tomb)
pb.watcher.Start(&tomb.Tomb{})
for {
fmt.Printf("looping")
select {
case profileAlert := <-pb.PluginChannel:
pb.addProfileAlert(profileAlert)
@ -116,8 +117,25 @@ func (pb *PluginBroker) Run(tomb *tomb.Tomb) {
log.WithField("plugin:", pluginName).Error(err)
}
}()
case <-tomb.Dying():
case <-pluginTomb.Dying():
pb.watcher.tomb.Kill(errors.New("Terminating"))
loop:
for {
select {
case pluginName := <-pb.watcher.PluginEvents:
// this can be ran in goroutine, but then locks will be needed
pluginMutex.Lock()
log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName)
tmpAlerts := pb.alertsByPluginName[pluginName]
pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0)
pluginMutex.Unlock()
if err := pb.pushNotificationsToPlugin(pluginName, tmpAlerts); err != nil {
log.WithField("plugin:", pluginName).Error(err)
}
case <-pb.watcher.tomb.Dead():
break loop
}
}
log.Info("killing all plugins")
pb.Kill()
return
@ -133,7 +151,10 @@ func (pb *PluginBroker) addProfileAlert(profileAlert ProfileAlert) {
pluginMutex.Lock()
pb.alertsByPluginName[pluginName] = append(pb.alertsByPluginName[pluginName], profileAlert.Alert)
pluginMutex.Unlock()
pb.watcher.Inserts <- pluginName
if _, ok := pb.watcher.PluginConfigByName[pluginName]; ok {
curr, _ := pb.watcher.AlertCountByPluginName.Get(pluginName)
pb.watcher.AlertCountByPluginName.Set(pluginName, curr+1)
}
}
}
func (pb *PluginBroker) profilesContainPlugin(pluginName string) bool {

View file

@ -73,11 +73,6 @@ func (pw *PluginWatcher) Start(tomb *tomb.Tomb) {
return nil
})
}
pw.tomb.Go(func() error {
pw.watchPluginAlertCounts()
return nil
})
}
func (pw *PluginWatcher) watchPluginTicker(pluginName string) {
@ -139,21 +134,8 @@ func (pw *PluginWatcher) watchPluginTicker(pluginName string) {
}
case <-pw.tomb.Dying():
ticker.Stop()
return
}
}
}
func (pw *PluginWatcher) watchPluginAlertCounts() {
for {
select {
case pluginName := <-pw.Inserts:
//we only "count" pending alerts, and watchPluginTicker is actually going to send it
if _, ok := pw.PluginConfigByName[pluginName]; ok {
curr, _ := pw.AlertCountByPluginName.Get(pluginName)
pw.AlertCountByPluginName.Set(pluginName, curr+1)
}
case <-pw.tomb.Dying():
pw.PluginEvents <- pluginName
log.Tracef("sending alerts to %s", pluginName)
return
}
}