From edced6818a1bdd9fba75b1c3a91cbeb9be782280 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Fri, 30 Sep 2022 16:01:42 +0200 Subject: [PATCH] cleanup + fix flaky tests in file_test.go, apic_test.go (#1773) --- pkg/acquisition/modules/file/file_test.go | 442 +++++++++++----------- pkg/apiserver/apic.go | 43 +-- pkg/apiserver/apic_test.go | 276 +++++++------- pkg/apiserver/apiserver.go | 5 +- pkg/cstest/utils.go | 14 + 5 files changed, 391 insertions(+), 389 deletions(-) diff --git a/pkg/acquisition/modules/file/file_test.go b/pkg/acquisition/modules/file/file_test.go index 3404453d4..5e88c1cfb 100644 --- a/pkg/acquisition/modules/file/file_test.go +++ b/pkg/acquisition/modules/file/file_test.go @@ -1,4 +1,4 @@ -package fileacquisition +package fileacquisition_test import ( "fmt" @@ -7,32 +7,39 @@ import ( "testing" "time" + fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file" "github.com/crowdsecurity/crowdsec/pkg/cstest" "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" ) func TestBadConfiguration(t *testing.T) { tests := []struct { + name string config string expectedErr string }{ { - config: `foobar: asd.log`, + name: "extra configuration key", + config: "foobar: asd.log", expectedErr: "line 1: field foobar not found in type fileacquisition.FileConfiguration", }, { - config: `mode: tail`, + name: "missing filenames", + config: "mode: tail", expectedErr: "no filename or filenames configuration provided", }, { + name: "glob syntax error", config: `filename: "[asd-.log"`, expectedErr: "Glob failure: syntax error in pattern", }, { + name: "bad exclude regexp", config: `filenames: ["asd.log"] exclude_regexps: ["as[a-$d"]`, expectedErr: "Could not compile regexp as", @@ -42,20 +49,24 @@ exclude_regexps: ["as[a-$d"]`, subLogger := log.WithFields(log.Fields{ "type": "file", }) - for _, test := range tests { - f := FileSource{} - err := f.Configure([]byte(test.config), subLogger) - assert.Contains(t, err.Error(), test.expectedErr) + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + f := fileacquisition.FileSource{} + err := f.Configure([]byte(tc.config), subLogger) + cstest.RequireErrorContains(t, err, tc.expectedErr) + }) } } func TestConfigureDSN(t *testing.T) { - var file string - if runtime.GOOS != "windows" { - file = "/etc/passwd" - } else { - file = "C:\\Windows\\System32\\drivers\\etc\\hosts" + file := "/etc/passwd" + + if runtime.GOOS == "windows" { + file = `C:\Windows\System32\drivers\etc\hosts` } + tests := []struct { dsn string expectedErr string @@ -69,37 +80,41 @@ func TestConfigureDSN(t *testing.T) { expectedErr: "empty file:// DSN", }, { - dsn: fmt.Sprintf("file://%s?log_level=warn", file), - expectedErr: "", + dsn: fmt.Sprintf("file://%s?log_level=warn", file), }, { dsn: fmt.Sprintf("file://%s?log_level=foobar", file), expectedErr: "unknown level foobar: not a valid logrus Level:", }, } + subLogger := log.WithFields(log.Fields{ "type": "file", }) - for _, test := range tests { - f := FileSource{} - err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger) - cstest.AssertErrorContains(t, err, test.expectedErr) + + for _, tc := range tests { + tc := tc + t.Run(tc.dsn, func(t *testing.T) { + f := fileacquisition.FileSource{} + err := f.ConfigureByDSN(tc.dsn, map[string]string{"type": "testtype"}, subLogger) + cstest.RequireErrorContains(t, err, tc.expectedErr) + }) } } func TestOneShot(t *testing.T) { - var permDeniedFile string - var permDeniedError string - if runtime.GOOS != "windows" { - permDeniedFile = "/etc/shadow" - permDeniedError = "failed opening /etc/shadow: open /etc/shadow: permission denied" - } else { - //Technically, this is not a permission denied error, but we just want to test what happens - //if we do not have access to the file - permDeniedFile = "C:\\Windows\\System32\\config\\SAM" - permDeniedError = "failed opening C:\\Windows\\System32\\config\\SAM: open C:\\Windows\\System32\\config\\SAM: The process cannot access the file because it is being used by another process." + permDeniedFile := "/etc/shadow" + permDeniedError := "failed opening /etc/shadow: open /etc/shadow: permission denied" + + if runtime.GOOS == "windows" { + // Technically, this is not a permission denied error, but we just want to test what happens + // if we do not have access to the file + permDeniedFile = `C:\Windows\System32\config\SAM` + permDeniedError = `failed opening C:\Windows\System32\config\SAM: open C:\Windows\System32\config\SAM: The process cannot access the file because it is being used by another process.` } + tests := []struct { + name string config string expectedConfigErr string expectedErr string @@ -111,76 +126,68 @@ func TestOneShot(t *testing.T) { teardown func() }{ { + name: "permission denied", config: fmt.Sprintf(` mode: cat filename: %s`, permDeniedFile), - expectedConfigErr: "", - expectedErr: permDeniedError, - expectedOutput: "", - logLevel: log.WarnLevel, - expectedLines: 0, + expectedErr: permDeniedError, + logLevel: log.WarnLevel, + expectedLines: 0, }, { + name: "ignored directory", config: ` mode: cat filename: /`, - expectedConfigErr: "", - expectedErr: "", - expectedOutput: "/ is a directory, ignoring it", - logLevel: log.WarnLevel, - expectedLines: 0, + expectedOutput: "/ is a directory, ignoring it", + logLevel: log.WarnLevel, + expectedLines: 0, }, { + name: "glob syntax error", config: ` mode: cat filename: "[*-.log"`, expectedConfigErr: "Glob failure: syntax error in pattern", - expectedErr: "", - expectedOutput: "", logLevel: log.WarnLevel, expectedLines: 0, }, { + name: "no matching files", config: ` mode: cat filename: /do/not/exist`, - expectedConfigErr: "", - expectedErr: "", - expectedOutput: "No matching files for pattern /do/not/exist", - logLevel: log.WarnLevel, - expectedLines: 0, + expectedOutput: "No matching files for pattern /do/not/exist", + logLevel: log.WarnLevel, + expectedLines: 0, }, { + name: "test.log", config: ` mode: cat filename: test_files/test.log`, - expectedConfigErr: "", - expectedErr: "", - expectedOutput: "", - expectedLines: 5, - logLevel: log.WarnLevel, + expectedLines: 5, + logLevel: log.WarnLevel, }, { + name: "test.log.gz", config: ` mode: cat filename: test_files/test.log.gz`, - expectedConfigErr: "", - expectedErr: "", - expectedOutput: "", - expectedLines: 5, - logLevel: log.WarnLevel, + expectedLines: 5, + logLevel: log.WarnLevel, }, { + name: "unexpected end of gzip stream", config: ` mode: cat filename: test_files/bad.gz`, - expectedConfigErr: "", - expectedErr: "failed to read gz test_files/bad.gz: unexpected EOF", - expectedOutput: "", - expectedLines: 0, - logLevel: log.WarnLevel, + expectedErr: "failed to read gz test_files/bad.gz: unexpected EOF", + expectedLines: 0, + logLevel: log.WarnLevel, }, { + name: "deleted file", config: ` mode: cat filename: test_files/test_delete.log`, @@ -195,77 +202,84 @@ filename: test_files/test_delete.log`, }, } - for _, ts := range tests { - logger, hook := test.NewNullLogger() - logger.SetLevel(ts.logLevel) - subLogger := logger.WithFields(log.Fields{ - "type": "file", - }) - tomb := tomb.Tomb{} - out := make(chan types.Event) - f := FileSource{} - if ts.setup != nil { - ts.setup() - } - err := f.Configure([]byte(ts.config), subLogger) - cstest.AssertErrorContains(t, err, ts.expectedConfigErr) - if err != nil { - continue - } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + logger, hook := test.NewNullLogger() + logger.SetLevel(tc.logLevel) - if ts.afterConfigure != nil { - ts.afterConfigure() - } - actualLines := 0 - if ts.expectedLines != 0 { - go func() { - READLOOP: - for { - select { - case <-out: - actualLines++ - case <-time.After(1 * time.Second): - break READLOOP + subLogger := logger.WithFields(log.Fields{ + "type": "file", + }) + + tomb := tomb.Tomb{} + out := make(chan types.Event) + f := fileacquisition.FileSource{} + + if tc.setup != nil { + tc.setup() + } + + err := f.Configure([]byte(tc.config), subLogger) + cstest.RequireErrorContains(t, err, tc.expectedConfigErr) + if tc.expectedConfigErr != "" { + return + } + + if tc.afterConfigure != nil { + tc.afterConfigure() + } + + actualLines := 0 + if tc.expectedLines != 0 { + go func() { + for { + select { + case <-out: + actualLines++ + case <-time.After(2 * time.Second): + return + } } - } - }() - } - err = f.OneShotAcquisition(out, &tomb) - cstest.AssertErrorContains(t, err, ts.expectedErr) + }() + } - if ts.expectedLines != 0 { - assert.Equal(t, actualLines, ts.expectedLines) - } - if ts.expectedOutput != "" { - assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) - hook.Reset() - } - if ts.teardown != nil { - ts.teardown() - } + err = f.OneShotAcquisition(out, &tomb) + cstest.RequireErrorContains(t, err, tc.expectedErr) + + if tc.expectedLines != 0 { + assert.Equal(t, tc.expectedLines, actualLines) + } + + if tc.expectedOutput != "" { + assert.Contains(t, hook.LastEntry().Message, tc.expectedOutput) + hook.Reset() + } + if tc.teardown != nil { + tc.teardown() + } + }) } } func TestLiveAcquisition(t *testing.T) { - var permDeniedFile string - var permDeniedError string - var testPattern string - if runtime.GOOS != "windows" { - permDeniedFile = "/etc/shadow" - permDeniedError = "unable to read /etc/shadow : open /etc/shadow: permission denied" - testPattern = "test_files/*.log" - } else { - //Technically, this is not a permission denied error, but we just want to test what happens - //if we do not have access to the file - permDeniedFile = "C:\\Windows\\System32\\config\\SAM" - permDeniedError = "unable to read C:\\Windows\\System32\\config\\SAM : open C:\\Windows\\System32\\config\\SAM: The process cannot access the file because it is being used by another process" - testPattern = "test_files\\\\*.log" // the \ must be escaped twice: once for the string, once for the yaml config + permDeniedFile := "/etc/shadow" + permDeniedError := "unable to read /etc/shadow : open /etc/shadow: permission denied" + testPattern := "test_files/*.log" + + if runtime.GOOS == "windows" { + // Technically, this is not a permission denied error, but we just want to test what happens + // if we do not have access to the file + permDeniedFile = `C:\Windows\System32\config\SAM` + permDeniedError = `unable to read C:\Windows\System32\config\SAM : open C:\Windows\System32\config\SAM: The process cannot access the file because it is being used by another process` + testPattern = `test_files\\*.log` // the \ must be escaped for the yaml config } + tests := []struct { + name string config string expectedErr string expectedOutput string - name string expectedLines int logLevel log.Level setup func() @@ -276,7 +290,6 @@ func TestLiveAcquisition(t *testing.T) { config: fmt.Sprintf(` mode: tail filename: %s`, permDeniedFile), - expectedErr: "", expectedOutput: permDeniedError, logLevel: log.InfoLevel, expectedLines: 0, @@ -286,7 +299,6 @@ filename: %s`, permDeniedFile), config: ` mode: tail filename: /`, - expectedErr: "", expectedOutput: "/ is a directory, ignoring it", logLevel: log.WarnLevel, expectedLines: 0, @@ -296,7 +308,6 @@ filename: /`, config: ` mode: tail filename: /do/not/exist`, - expectedErr: "", expectedOutput: "No matching files for pattern /do/not/exist", logLevel: log.WarnLevel, expectedLines: 0, @@ -308,11 +319,9 @@ mode: tail filenames: - %s force_inotify: true`, testPattern), - expectedErr: "", - expectedOutput: "", - expectedLines: 5, - logLevel: log.DebugLevel, - name: "basicGlob", + expectedLines: 5, + logLevel: log.DebugLevel, + name: "basicGlob", }, { config: fmt.Sprintf(` @@ -320,11 +329,9 @@ mode: tail filenames: - %s force_inotify: true`, testPattern), - expectedErr: "", - expectedOutput: "", - expectedLines: 0, - logLevel: log.DebugLevel, - name: "GlobInotify", + expectedLines: 0, + logLevel: log.DebugLevel, + name: "GlobInotify", afterConfigure: func() { f, _ := os.Create("test_files/a.log") f.Close() @@ -338,19 +345,17 @@ mode: tail filenames: - %s force_inotify: true`, testPattern), - expectedErr: "", - expectedOutput: "", - expectedLines: 5, - logLevel: log.DebugLevel, - name: "GlobInotifyChmod", + expectedLines: 5, + logLevel: log.DebugLevel, + name: "GlobInotifyChmod", afterConfigure: func() { f, _ := os.Create("test_files/a.log") f.Close() time.Sleep(1 * time.Second) - os.Chmod("test_files/a.log", 0000) + os.Chmod("test_files/a.log", 0o000) }, teardown: func() { - os.Chmod("test_files/a.log", 0644) + os.Chmod("test_files/a.log", 0o644) os.Remove("test_files/a.log") }, }, @@ -360,13 +365,11 @@ mode: tail filenames: - %s force_inotify: true`, testPattern), - expectedErr: "", - expectedOutput: "", - expectedLines: 5, - logLevel: log.DebugLevel, - name: "InotifyMkDir", + expectedLines: 5, + logLevel: log.DebugLevel, + name: "InotifyMkDir", afterConfigure: func() { - os.Mkdir("test_files/pouet/", 0700) + os.Mkdir("test_files/pouet/", 0o700) }, teardown: func() { os.Remove("test_files/pouet/") @@ -374,101 +377,112 @@ force_inotify: true`, testPattern), }, } - for _, ts := range tests { - t.Logf("test: %s", ts.name) - logger, hook := test.NewNullLogger() - logger.SetLevel(ts.logLevel) - subLogger := logger.WithFields(log.Fields{ - "type": "file", - }) - tomb := tomb.Tomb{} - out := make(chan types.Event) - f := FileSource{} - if ts.setup != nil { - ts.setup() - } - err := f.Configure([]byte(ts.config), subLogger) - if err != nil { - t.Fatalf("Unexpected error : %s", err) - } - if ts.afterConfigure != nil { - ts.afterConfigure() - } - actualLines := 0 - if ts.expectedLines != 0 { - go func() { - READLOOP: - for { - select { - case <-out: - actualLines++ - case <-time.After(2 * time.Second): - break READLOOP + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + logger, hook := test.NewNullLogger() + logger.SetLevel(tc.logLevel) + + subLogger := logger.WithFields(log.Fields{ + "type": "file", + }) + + tomb := tomb.Tomb{} + out := make(chan types.Event) + + f := fileacquisition.FileSource{} + + if tc.setup != nil { + tc.setup() + } + + err := f.Configure([]byte(tc.config), subLogger) + require.NoError(t, err) + + if tc.afterConfigure != nil { + tc.afterConfigure() + } + + actualLines := 0 + if tc.expectedLines != 0 { + go func() { + for { + select { + case <-out: + actualLines++ + case <-time.After(2 * time.Second): + return + } + } + }() + } + + err = f.StreamingAcquisition(out, &tomb) + cstest.RequireErrorContains(t, err, tc.expectedErr) + + if tc.expectedLines != 0 { + fd, err := os.Create("test_files/stream.log") + if err != nil { + t.Fatalf("could not create test file : %s", err) + } + + for i := 0; i < 5; i++ { + _, err = fmt.Fprintf(fd, "%d\n", i) + if err != nil { + t.Fatalf("could not write test file : %s", err) + os.Remove("test_files/stream.log") } } - }() - } - err = f.StreamingAcquisition(out, &tomb) - cstest.AssertErrorContains(t, err, ts.expectedErr) - if ts.expectedLines != 0 { - fd, err := os.Create("test_files/stream.log") - if err != nil { - t.Fatalf("could not create test file : %s", err) + fd.Close() + // we sleep to make sure we detect the new file + time.Sleep(1 * time.Second) + os.Remove("test_files/stream.log") + assert.Equal(t, tc.expectedLines, actualLines) } - for i := 0; i < 5; i++ { - _, err = fd.WriteString(fmt.Sprintf("%d\n", i)) - if err != nil { - t.Fatalf("could not write test file : %s", err) - os.Remove("test_files/stream.log") + + if tc.expectedOutput != "" { + if hook.LastEntry() == nil { + t.Fatalf("expected output %s, but got nothing", tc.expectedOutput) } + + assert.Contains(t, hook.LastEntry().Message, tc.expectedOutput) + hook.Reset() } - fd.Close() - //we sleep to make sure we detect the new file - time.Sleep(1 * time.Second) - os.Remove("test_files/stream.log") - assert.Equal(t, ts.expectedLines, actualLines) - } - if ts.expectedOutput != "" { - if hook.LastEntry() == nil { - t.Fatalf("expected output %s, but got nothing", ts.expectedOutput) + if tc.teardown != nil { + tc.teardown() } - assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) - hook.Reset() - } - if ts.teardown != nil { - ts.teardown() - } - - tomb.Kill(nil) + tomb.Kill(nil) + }) } } func TestExclusion(t *testing.T) { - config := `filenames: ["test_files/*.log*"] exclude_regexps: ["\\.gz$"]` logger, hook := test.NewNullLogger() - //logger.SetLevel(ts.logLevel) + // logger.SetLevel(ts.logLevel) subLogger := logger.WithFields(log.Fields{ "type": "file", }) - f := FileSource{} - err := f.Configure([]byte(config), subLogger) - if err != nil { + + f := fileacquisition.FileSource{} + if err := f.Configure([]byte(config), subLogger); err != nil { subLogger.Fatalf("unexpected error: %s", err) } - var expectedLogOutput string + + expectedLogOutput := "Skipping file test_files/test.log.gz as it matches exclude pattern" + if runtime.GOOS == "windows" { - expectedLogOutput = "Skipping file test_files\\test.log.gz as it matches exclude pattern \\.gz" - } else { - expectedLogOutput = "Skipping file test_files/test.log.gz as it matches exclude pattern" + expectedLogOutput = `Skipping file test_files\test.log.gz as it matches exclude pattern \.gz` } + if hook.LastEntry() == nil { t.Fatalf("expected output %s, but got nothing", expectedLogOutput) } + assert.Contains(t, hook.LastEntry().Message, expectedLogOutput) hook.Reset() } diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 1adc51afe..586ad56b7 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -400,7 +400,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio return alerts } -//we receive only one list of decisions, that we need to break-up : +// we receive only one list of decisions, that we need to break-up : // one alert for "community blocklist" // one alert per list we're subscribed to func (a *apic) PullTop() error { @@ -432,7 +432,7 @@ func (a *apic) PullTop() error { return nil } - //we receive only one list of decisions, that we need to break-up : + // we receive only one list of decisions, that we need to break-up : // one alert for "community blocklist" // one alert per list we're subscribed to alertsFromCapi := createAlertsForDecisions(data.New) @@ -541,37 +541,32 @@ func (a *apic) GetMetrics() (*models.Metrics, error) { return metric, nil } -func (a *apic) SendMetrics() error { +func (a *apic) SendMetrics(stop chan (bool)) { defer types.CatchPanic("lapi/metricsToAPIC") - metrics, err := a.GetMetrics() - if err != nil { - log.Errorf("unable to get metrics (%s), will retry", err) - } - _, _, err = a.apiClient.Metrics.Add(context.Background(), metrics) - if err != nil { - log.Errorf("unable to send metrics (%s), will retry", err) - } - log.Infof("capi metrics: metrics sent successfully") - log.Infof("Start send metrics to CrowdSec Central API (interval: %s)", MetricsInterval) + log.Infof("Start send metrics to CrowdSec Central API (interval: %s)", a.metricsInterval) ticker := time.NewTicker(a.metricsInterval) for { + metrics, err := a.GetMetrics() + if err != nil { + log.Errorf("unable to get metrics (%s), will retry", err) + } + _, _, err = a.apiClient.Metrics.Add(context.Background(), metrics) + if err != nil { + log.Errorf("capi metrics: failed: %s", err) + } else { + log.Infof("capi metrics: metrics sent successfully") + } + select { + case <-stop: + return case <-ticker.C: - metrics, err := a.GetMetrics() - if err != nil { - log.Errorf("unable to get metrics (%s), will retry", err) - } - _, _, err = a.apiClient.Metrics.Add(context.Background(), metrics) - if err != nil { - log.Errorf("capi metrics: failed: %s", err) - } else { - log.Infof("capi metrics: metrics sent successfully") - } + continue case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others? a.pullTomb.Kill(nil) a.pushTomb.Kill(nil) - return nil + return } } } diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 6b4841ee0..2b20b2db0 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -8,13 +8,13 @@ import ( "net/url" "os" "reflect" - "sort" "sync" "testing" "time" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/csconfig" + "github.com/crowdsecurity/crowdsec/pkg/cstest" "github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" @@ -24,23 +24,20 @@ import ( "github.com/jarcoal/httpmock" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" ) func getDBClient(t *testing.T) *database.Client { t.Helper() dbPath, err := os.CreateTemp("", "*sqlite") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) dbClient, err := database.NewClient(&csconfig.DatabaseCfg{ Type: "sqlite", DbName: "crowdsec", DbPath: dbPath.Name(), }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) return dbClient } @@ -98,11 +95,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) { func TestAPICCAPIPullIsOld(t *testing.T) { api := getAPIC(t) - isOld, err := api.CAPIPullIsOld() - if err != nil { - t.Fatal(err) - } + isOld, err := api.CAPIPullIsOld() + require.NoError(t, err) assert.True(t, isOld) decision := api.dbClient.Ent.Decision.Create(). @@ -123,16 +118,13 @@ func TestAPICCAPIPullIsOld(t *testing.T) { SaveX(context.Background()) isOld, err = api.CAPIPullIsOld() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.False(t, isOld) } func TestAPICFetchScenariosListFromDB(t *testing.T) { - api := getAPIC(t) - testCases := []struct { + tests := []struct { name string machineIDsWithScenarios map[string]string expectedScenarios []string @@ -154,8 +146,10 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { }, } - for _, tc := range testCases { + for _, tc := range tests { + tc := tc t.Run(tc.name, func(t *testing.T) { + api := getAPIC(t) for machineID, scenarios := range tc.machineIDsWithScenarios { api.dbClient.Ent.Machine.Create(). SetMachineId(machineID). @@ -164,17 +158,14 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { SetScenarios(scenarios). ExecX(context.Background()) } + scenarios, err := api.FetchScenariosListFromDB() for machineID := range tc.machineIDsWithScenarios { api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background()) } - if err != nil { - t.Fatal(err) - } else { - sort.Strings(scenarios) - sort.Strings(tc.expectedScenarios) - assert.Equal(t, scenarios, tc.expectedScenarios) - } + require.NoError(t, err) + + assert.ElementsMatch(t, tc.expectedScenarios, scenarios) }) } @@ -196,11 +187,10 @@ func TestNewAPIC(t *testing.T) { consoleConfig *csconfig.ConsoleConfig } tests := []struct { - name string - args args - wantErr bool - errorContains string - action func() + name string + args args + expectedErr string + action func() }{ { name: "simple", @@ -217,20 +207,16 @@ func TestNewAPIC(t *testing.T) { dbClient: getDBClient(t), consoleConfig: LoadTestConfig().API.Server.ConsoleConfig, }, - wantErr: true, - errorContains: "first path segment in URL cannot contain colon", + expectedErr: "first path segment in URL cannot contain colon", }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { setConfig() - tt.action() - _, err := NewAPIC(testConfig, tt.args.dbClient, tt.args.consoleConfig) - if tt.wantErr { - assert.ErrorContains(t, err, tt.errorContains) - } else { - assert.NoError(t, err) - } + tc.action() + _, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig) + cstest.RequireErrorContains(t, err, tc.expectedErr) }) } } @@ -268,17 +254,16 @@ func TestAPICHandleDeletedDecisions(t *testing.T) { }}, deleteCounters) assert.NoError(t, err) - assert.Equal(t, nbDeleted, 2) - assert.Equal(t, deleteCounters[SCOPE_CAPI]["all"], 2) + assert.Equal(t, 2, nbDeleted) + assert.Equal(t, 2, deleteCounters[SCOPE_CAPI]["all"]) } func TestAPICGetMetrics(t *testing.T) { - api := getAPIC(t) - cleanUp := func() { + cleanUp := func(api *apic) { api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) } - testCases := []struct { + tests := []struct { name string machineIDs []string bouncers []string @@ -322,11 +307,13 @@ func TestAPICGetMetrics(t *testing.T) { }, }, } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - cleanUp() - for i, machineID := range testCase.machineIDs { - api.dbClient.Ent.Machine.Create(). + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + apiClient := getAPIC(t) + cleanUp(apiClient) + for i, machineID := range tc.machineIDs { + apiClient.dbClient.Ent.Machine.Create(). SetMachineId(machineID). SetPassword(testPassword.String()). SetIpAddress(fmt.Sprintf("1.2.3.%d", i)). @@ -336,8 +323,8 @@ func TestAPICGetMetrics(t *testing.T) { ExecX(context.Background()) } - for i, bouncerName := range testCase.bouncers { - api.dbClient.Ent.Bouncer.Create(). + for i, bouncerName := range tc.bouncers { + apiClient.dbClient.Ent.Bouncer.Create(). SetIPAddress(fmt.Sprintf("1.2.3.%d", i)). SetName(bouncerName). SetAPIKey("foobar"). @@ -346,19 +333,17 @@ func TestAPICGetMetrics(t *testing.T) { ExecX(context.Background()) } - if foundMetrics, err := api.GetMetrics(); err != nil { - t.Fatal(err) - } else { - assert.Equal(t, foundMetrics.Bouncers, testCase.expectedMetric.Bouncers) - assert.Equal(t, foundMetrics.Machines, testCase.expectedMetric.Machines) + foundMetrics, err := apiClient.GetMetrics() + require.NoError(t, err) + + assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers) + assert.Equal(t, tc.expectedMetric.Machines, foundMetrics.Machines) - } }) } } func TestCreateAlertsForDecision(t *testing.T) { - httpBfDecisionList := &models.Decision{ Origin: &SCOPE_LISTS, Scenario: types.StrPtr("crowdsecurity/http-bf"), @@ -427,10 +412,11 @@ func TestCreateAlertsForDecision(t *testing.T) { }, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := createAlertsForDecisions(tt.args.decisions); !reflect.DeepEqual(got, tt.want) { - t.Errorf("createAlertsForDecisions() = %v, want %v", got, tt.want) + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + if got := createAlertsForDecisions(tc.args.decisions); !reflect.DeepEqual(got, tc.want) { + t.Errorf("createAlertsForDecisions() = %v, want %v", got, tc.want) } }) } @@ -503,11 +489,12 @@ func TestFillAlertsWithDecisions(t *testing.T) { }, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - add_counters, _ := makeAddAndDeleteCounters() - if got := fillAlertsWithDecisions(tt.args.alerts, tt.args.decisions, add_counters); !reflect.DeepEqual(got, tt.want) { - t.Errorf("fillAlertsWithDecisions() = %v, want %v", got, tt.want) + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + addCounters, _ := makeAddAndDeleteCounters() + if got := fillAlertsWithDecisions(tc.args.alerts, tc.args.decisions, addCounters); !reflect.DeepEqual(got, tc.want) { + t.Errorf("fillAlertsWithDecisions() = %v, want %v", got, tc.want) } }) } @@ -586,24 +573,19 @@ func TestAPICPullTop(t *testing.T) { ), )) url, err := url.ParseRequestURI("http://api.crowdsec.net/") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + apic, err := apiclient.NewDefaultClient( url, "/api", fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), nil, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) api.apiClient = apic err = api.PullTop() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assertTotalDecisionCount(t, api.dbClient, 5) assertTotalValidDecisionCount(t, api.dbClient, 4) @@ -619,24 +601,23 @@ func TestAPICPullTop(t *testing.T) { for _, alert := range alerts { alertScenario[alert.SourceScope]++ } - assert.Equal(t, len(alertScenario), 3) - assert.Equal(t, alertScenario[SCOPE_CAPI_ALIAS], 1) - assert.Equal(t, alertScenario["lists:crowdsecurity/ssh-bf"], 1) - assert.Equal(t, alertScenario["lists:crowdsecurity/http-bf"], 1) + assert.Equal(t, 3, len(alertScenario)) + assert.Equal(t, 1, alertScenario[SCOPE_CAPI_ALIAS]) + assert.Equal(t, 1, alertScenario["lists:crowdsecurity/ssh-bf"]) + assert.Equal(t, 1, alertScenario["lists:crowdsecurity/http-bf"]) for _, decisions := range validDecisions { decisionScenarioFreq[decisions.Scenario]++ } - assert.Equal(t, decisionScenarioFreq["crowdsecurity/http-bf"], 1) - assert.Equal(t, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1) - assert.Equal(t, decisionScenarioFreq["crowdsecurity/test1"], 1) - assert.Equal(t, decisionScenarioFreq["crowdsecurity/test2"], 1) + assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/http-bf"], 1) + assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1) + assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test1"], 1) + assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test2"], 1) } func TestAPICPush(t *testing.T) { - - testCases := []struct { + tests := []struct { name string alerts []*models.Alert expectedCalls int @@ -683,14 +664,14 @@ func TestAPICPush(t *testing.T) { }, } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { api := getAPIC(t) api.pushInterval = time.Millisecond url, err := url.ParseRequestURI("http://api.crowdsec.net/") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + httpmock.Activate() defer httpmock.DeactivateAndReset() apic, err := apiclient.NewDefaultClient( @@ -699,31 +680,28 @@ func TestAPICPush(t *testing.T) { fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), nil, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + api.apiClient = apic httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/signals", httpmock.NewBytesResponder(200, []byte{})) go func() { - api.alertToPush <- testCase.alerts + api.alertToPush <- tc.alerts time.Sleep(time.Second) api.Shutdown() }() - if err := api.Push(); err != nil { - t.Fatal(err) - } - assert.Equal(t, httpmock.GetTotalCallCount(), testCase.expectedCalls) + err = api.Push() + require.NoError(t, err) + assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount()) }) } } func TestAPICSendMetrics(t *testing.T) { - api := getAPIC(t) - testCases := []struct { + tests := []struct { name string duration time.Duration expectedCalls int - setUp func() + setUp func(*apic) metricsInterval time.Duration }{ { @@ -731,14 +709,15 @@ func TestAPICSendMetrics(t *testing.T) { duration: time.Millisecond * 30, metricsInterval: time.Millisecond * 5, expectedCalls: 5, - setUp: func() {}, + setUp: func(api *apic) {}, }, { name: "with some metrics", duration: time.Millisecond * 30, metricsInterval: time.Millisecond * 5, expectedCalls: 5, - setUp: func() { + setUp: func(api *apic) { + api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) api.dbClient.Ent.Machine.Create(). SetMachineId("1234"). SetPassword(testPassword.String()). @@ -748,6 +727,7 @@ func TestAPICSendMetrics(t *testing.T) { SetUpdatedAt(time.Time{}). ExecX(context.Background()) + api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) api.dbClient.Ent.Bouncer.Create(). SetIPAddress("1.2.3.6"). SetName("someBouncer"). @@ -758,44 +738,49 @@ func TestAPICSendMetrics(t *testing.T) { }, }, } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - api = getAPIC(t) - api.pushInterval = time.Millisecond + + httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{})) + httpmock.Activate() + defer httpmock.Deactivate() + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { url, err := url.ParseRequestURI("http://api.crowdsec.net/") - if err != nil { - t.Fatal(err) - } - httpmock.Activate() - defer httpmock.DeactivateAndReset() - apic, err := apiclient.NewDefaultClient( + require.NoError(t, err) + + apiClient, err := apiclient.NewDefaultClient( url, "/api", fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), nil, ) - if err != nil { - t.Fatal(err) - } - api.apiClient = apic - api.metricsInterval = testCase.metricsInterval - httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, []byte{})) - testCase.setUp() + require.NoError(t, err) - go func() { - if err := api.SendMetrics(); err != nil { - panic(err) - } - }() - time.Sleep(testCase.duration) - assert.LessOrEqual(t, absDiff(testCase.expectedCalls, httpmock.GetTotalCallCount()), 2) + api := getAPIC(t) + api.pushInterval = time.Millisecond + api.apiClient = apiClient + api.metricsInterval = tc.metricsInterval + tc.setUp(api) + + stop := make(chan bool) + httpmock.ZeroCallCounters() + go api.SendMetrics(stop) + time.Sleep(tc.duration) + stop <- true + + info := httpmock.GetCallCountInfo() + noResponderCalls := info["NO_RESPONDER"] + responderCalls := info["POST http://api.crowdsec.net/api/metrics/"] + assert.LessOrEqual(t, absDiff(tc.expectedCalls, responderCalls), 2) + assert.Zero(t, noResponderCalls) }) } } func TestAPICPull(t *testing.T) { api := getAPIC(t) - testCases := []struct { + tests := []struct { name string setUp func() expectedDecisionCount int @@ -820,14 +805,13 @@ func TestAPICPull(t *testing.T) { }, } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { api = getAPIC(t) api.pullInterval = time.Millisecond url, err := url.ParseRequestURI("http://api.crowdsec.net/") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) httpmock.Activate() defer httpmock.DeactivateAndReset() apic, err := apiclient.NewDefaultClient( @@ -836,9 +820,7 @@ func TestAPICPull(t *testing.T) { fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), nil, ) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) api.apiClient = apic httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, jsonMarshalX( models.DecisionsStreamResponse{ @@ -854,7 +836,7 @@ func TestAPICPull(t *testing.T) { }, }, ))) - testCase.setUp() + tc.setUp() var buf bytes.Buffer go func() { logrus.SetOutput(&buf) @@ -865,15 +847,14 @@ func TestAPICPull(t *testing.T) { //Slightly long because the CI runner for windows are slow, and this can lead to random failure time.Sleep(time.Millisecond * 500) logrus.SetOutput(os.Stderr) - assert.Contains(t, buf.String(), testCase.logContains) - assertTotalDecisionCount(t, api.dbClient, testCase.expectedDecisionCount) + assert.Contains(t, buf.String(), tc.logContains) + assertTotalDecisionCount(t, api.dbClient, tc.expectedDecisionCount) }) } } func TestShouldShareAlert(t *testing.T) { - - testCases := []struct { + tests := []struct { name string consoleConfig *csconfig.ConsoleConfig alert *models.Alert @@ -948,10 +929,11 @@ func TestShouldShareAlert(t *testing.T) { }, } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - ret := shouldShareAlert(testCase.alert, testCase.consoleConfig) - assert.Equal(t, ret, testCase.expectedRet) + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + ret := shouldShareAlert(tc.alert, tc.consoleConfig) + assert.Equal(t, tc.expectedRet, ret) }) } } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 460fd45a0..d59454f0b 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -311,10 +311,7 @@ func (s *APIServer) Run(apiReady chan bool) error { return nil }) s.apic.metricsTomb.Go(func() error { - if err := s.apic.SendMetrics(); err != nil { - log.Errorf("capi metrics: %s", err) - return err - } + s.apic.SendMetrics(make(chan bool)) return nil }) } diff --git a/pkg/cstest/utils.go b/pkg/cstest/utils.go index 96ed256cc..f348903aa 100644 --- a/pkg/cstest/utils.go +++ b/pkg/cstest/utils.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Copy(sourceFile string, destinationFile string) error { @@ -110,6 +111,8 @@ func CopyDir(src string, dest string) error { } func AssertErrorContains(t *testing.T, err error, expectedErr string) { + t.Helper() + if expectedErr != "" { assert.ErrorContains(t, err, expectedErr) return @@ -117,3 +120,14 @@ func AssertErrorContains(t *testing.T, err error, expectedErr string) { assert.NoError(t, err) } + +func RequireErrorContains(t *testing.T, err error, expectedErr string) { + t.Helper() + + if expectedErr != "" { + require.ErrorContains(t, err, expectedErr) + return + } + + require.NoError(t, err) +}