* add a way to test notifications
* fix an antipattern in broker system with tomb * kill an unneeded goroutine
This commit is contained in:
parent
18030e6c58
commit
4efedbed34
|
@ -1,20 +1,28 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/csv"
|
"encoding/csv"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
|
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
|
||||||
|
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
|
||||||
|
"github.com/go-openapi/strfmt"
|
||||||
"github.com/olekukonko/tablewriter"
|
"github.com/olekukonko/tablewriter"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"gopkg.in/tomb.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type NotificationsCfg struct {
|
type NotificationsCfg struct {
|
||||||
|
@ -135,6 +143,81 @@ func NewNotificationsCmd() *cobra.Command {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
cmdNotifications.AddCommand(cmdNotificationsInspect)
|
cmdNotifications.AddCommand(cmdNotificationsInspect)
|
||||||
|
var cmdNotificationsReinject = &cobra.Command{
|
||||||
|
Use: "reinject",
|
||||||
|
Short: "reinject alerts into notifications system",
|
||||||
|
Long: `Reinject alerts into notifications system`,
|
||||||
|
Example: `cscli notifications reinject <alert_id> <plugin_name>`,
|
||||||
|
Args: cobra.ExactArgs(2),
|
||||||
|
DisableAutoGenTag: true,
|
||||||
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
var (
|
||||||
|
pluginBroker csplugin.PluginBroker
|
||||||
|
pluginTomb tomb.Tomb
|
||||||
|
)
|
||||||
|
if len(args) != 2 {
|
||||||
|
printHelp(cmd)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id, err := strconv.Atoi(args[0])
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("bad alert id %s", args[0])
|
||||||
|
}
|
||||||
|
if err := csConfig.LoadAPIClient(); err != nil {
|
||||||
|
log.Fatalf("loading api client: %s", err.Error())
|
||||||
|
}
|
||||||
|
if csConfig.API.Client == nil {
|
||||||
|
log.Fatalln("There is no configuration on 'api_client:'")
|
||||||
|
}
|
||||||
|
if csConfig.API.Client.Credentials == nil {
|
||||||
|
log.Fatalf("Please provide credentials for the API in '%s'", csConfig.API.Client.CredentialsFilePath)
|
||||||
|
}
|
||||||
|
apiURL, err := url.Parse(csConfig.API.Client.Credentials.URL)
|
||||||
|
Client, err = apiclient.NewClient(&apiclient.Config{
|
||||||
|
MachineID: csConfig.API.Client.Credentials.Login,
|
||||||
|
Password: strfmt.Password(csConfig.API.Client.Credentials.Password),
|
||||||
|
UserAgent: fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
|
||||||
|
URL: apiURL,
|
||||||
|
VersionPrefix: "v1",
|
||||||
|
})
|
||||||
|
|
||||||
|
alert, _, err := Client.Alerts.GetByID(context.Background(), id)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("can't find alert with id %s: %s", args[0], err)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pluginBroker.Init(csConfig.PluginConfig, csConfig.API.Server.Profiles, csConfig.ConfigPaths)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Can't initialize plugins: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
pluginTomb.Go(func() error {
|
||||||
|
pluginBroker.Run(&pluginTomb)
|
||||||
|
fmt.Printf("\nreturned\n")
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case pluginBroker.PluginChannel <- csplugin.ProfileAlert{
|
||||||
|
ProfileID: 1,
|
||||||
|
Alert: alert,
|
||||||
|
}:
|
||||||
|
break loop
|
||||||
|
default:
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
log.Info("sleeping\n")
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pluginTomb.Kill(errors.New("terminating"))
|
||||||
|
pluginTomb.Wait()
|
||||||
|
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cmdNotifications.AddCommand(cmdNotificationsReinject)
|
||||||
return cmdNotifications
|
return cmdNotifications
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,10 @@ func initAPIServer(cConfig *csconfig.Config) (*apiserver.APIServer, error) {
|
||||||
return nil, fmt.Errorf("unable to run local API: %s", err)
|
return nil, fmt.Errorf("unable to run local API: %s", err)
|
||||||
}
|
}
|
||||||
log.Info("initiated plugin broker")
|
log.Info("initiated plugin broker")
|
||||||
|
apiserver := &apiserver.APIServer{
|
||||||
|
URL: "",
|
||||||
|
TLS: &csconfig.TLSCfg{},
|
||||||
|
}
|
||||||
apiServer.AttachPluginBroker(&pluginBroker)
|
apiServer.AttachPluginBroker(&pluginBroker)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -146,7 +146,7 @@ func TestCreateAlertChannels(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln(err.Error())
|
log.Fatalln(err.Error())
|
||||||
}
|
}
|
||||||
apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert)
|
apiServer.Controller.PluginChannel = make(chan csplugin.ProfileAlert)
|
||||||
apiServer.InitController()
|
apiServer.InitController()
|
||||||
|
|
||||||
loginResp, err := LoginToTestAPI(apiServer.router, config)
|
loginResp, err := LoginToTestAPI(apiServer.router, config)
|
||||||
|
@ -160,7 +160,7 @@ func TestCreateAlertChannels(t *testing.T) {
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
pd = <-apiServer.controller.PluginChannel
|
pd = <-apiServer.Controller.PluginChannel
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
@ -30,9 +30,9 @@ var (
|
||||||
type APIServer struct {
|
type APIServer struct {
|
||||||
URL string
|
URL string
|
||||||
TLS *csconfig.TLSCfg
|
TLS *csconfig.TLSCfg
|
||||||
|
Controller *controllers.Controller
|
||||||
dbClient *database.Client
|
dbClient *database.Client
|
||||||
logFile string
|
logFile string
|
||||||
controller *controllers.Controller
|
|
||||||
flushScheduler *gocron.Scheduler
|
flushScheduler *gocron.Scheduler
|
||||||
router *gin.Engine
|
router *gin.Engine
|
||||||
httpServer *http.Server
|
httpServer *http.Server
|
||||||
|
@ -227,7 +227,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
|
||||||
TLS: config.TLS,
|
TLS: config.TLS,
|
||||||
logFile: logFile,
|
logFile: logFile,
|
||||||
dbClient: dbClient,
|
dbClient: dbClient,
|
||||||
controller: controller,
|
Controller: controller,
|
||||||
flushScheduler: flushScheduler,
|
flushScheduler: flushScheduler,
|
||||||
router: router,
|
router: router,
|
||||||
apic: apiClient,
|
apic: apiClient,
|
||||||
|
@ -323,10 +323,10 @@ func (s *APIServer) Shutdown() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) {
|
func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) {
|
||||||
s.controller.PluginChannel = broker.PluginChannel
|
s.Controller.PluginChannel = broker.PluginChannel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *APIServer) InitController() error {
|
func (s *APIServer) InitController() error {
|
||||||
err := s.controller.Init()
|
err := s.Controller.Init()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,10 +96,11 @@ func (pb *PluginBroker) Kill() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pb *PluginBroker) Run(tomb *tomb.Tomb) {
|
func (pb *PluginBroker) Run(pluginTomb *tomb.Tomb) {
|
||||||
//we get signaled via the channel when notifications need to be delivered to plugin (via the watcher)
|
//we get signaled via the channel when notifications need to be delivered to plugin (via the watcher)
|
||||||
pb.watcher.Start(tomb)
|
pb.watcher.Start(&tomb.Tomb{})
|
||||||
for {
|
for {
|
||||||
|
fmt.Printf("looping")
|
||||||
select {
|
select {
|
||||||
case profileAlert := <-pb.PluginChannel:
|
case profileAlert := <-pb.PluginChannel:
|
||||||
pb.addProfileAlert(profileAlert)
|
pb.addProfileAlert(profileAlert)
|
||||||
|
@ -116,8 +117,25 @@ func (pb *PluginBroker) Run(tomb *tomb.Tomb) {
|
||||||
log.WithField("plugin:", pluginName).Error(err)
|
log.WithField("plugin:", pluginName).Error(err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
case <-pluginTomb.Dying():
|
||||||
case <-tomb.Dying():
|
pb.watcher.tomb.Kill(errors.New("Terminating"))
|
||||||
|
loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case pluginName := <-pb.watcher.PluginEvents:
|
||||||
|
// this can be ran in goroutine, but then locks will be needed
|
||||||
|
pluginMutex.Lock()
|
||||||
|
log.Tracef("going to deliver %d alerts to plugin %s", len(pb.alertsByPluginName[pluginName]), pluginName)
|
||||||
|
tmpAlerts := pb.alertsByPluginName[pluginName]
|
||||||
|
pb.alertsByPluginName[pluginName] = make([]*models.Alert, 0)
|
||||||
|
pluginMutex.Unlock()
|
||||||
|
if err := pb.pushNotificationsToPlugin(pluginName, tmpAlerts); err != nil {
|
||||||
|
log.WithField("plugin:", pluginName).Error(err)
|
||||||
|
}
|
||||||
|
case <-pb.watcher.tomb.Dead():
|
||||||
|
break loop
|
||||||
|
}
|
||||||
|
}
|
||||||
log.Info("killing all plugins")
|
log.Info("killing all plugins")
|
||||||
pb.Kill()
|
pb.Kill()
|
||||||
return
|
return
|
||||||
|
@ -133,7 +151,10 @@ func (pb *PluginBroker) addProfileAlert(profileAlert ProfileAlert) {
|
||||||
pluginMutex.Lock()
|
pluginMutex.Lock()
|
||||||
pb.alertsByPluginName[pluginName] = append(pb.alertsByPluginName[pluginName], profileAlert.Alert)
|
pb.alertsByPluginName[pluginName] = append(pb.alertsByPluginName[pluginName], profileAlert.Alert)
|
||||||
pluginMutex.Unlock()
|
pluginMutex.Unlock()
|
||||||
pb.watcher.Inserts <- pluginName
|
if _, ok := pb.watcher.PluginConfigByName[pluginName]; ok {
|
||||||
|
curr, _ := pb.watcher.AlertCountByPluginName.Get(pluginName)
|
||||||
|
pb.watcher.AlertCountByPluginName.Set(pluginName, curr+1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (pb *PluginBroker) profilesContainPlugin(pluginName string) bool {
|
func (pb *PluginBroker) profilesContainPlugin(pluginName string) bool {
|
||||||
|
|
|
@ -73,11 +73,6 @@ func (pw *PluginWatcher) Start(tomb *tomb.Tomb) {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pw.tomb.Go(func() error {
|
|
||||||
pw.watchPluginAlertCounts()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pw *PluginWatcher) watchPluginTicker(pluginName string) {
|
func (pw *PluginWatcher) watchPluginTicker(pluginName string) {
|
||||||
|
@ -139,21 +134,8 @@ func (pw *PluginWatcher) watchPluginTicker(pluginName string) {
|
||||||
}
|
}
|
||||||
case <-pw.tomb.Dying():
|
case <-pw.tomb.Dying():
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
return
|
pw.PluginEvents <- pluginName
|
||||||
}
|
log.Tracef("sending alerts to %s", pluginName)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pw *PluginWatcher) watchPluginAlertCounts() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case pluginName := <-pw.Inserts:
|
|
||||||
//we only "count" pending alerts, and watchPluginTicker is actually going to send it
|
|
||||||
if _, ok := pw.PluginConfigByName[pluginName]; ok {
|
|
||||||
curr, _ := pw.AlertCountByPluginName.Get(pluginName)
|
|
||||||
pw.AlertCountByPluginName.Set(pluginName, curr+1)
|
|
||||||
}
|
|
||||||
case <-pw.tomb.Dying():
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue