cleanup + fix flaky tests in file_test.go, apic_test.go (#1773)

This commit is contained in:
mmetc 2022-09-30 16:01:42 +02:00 committed by GitHub
parent 6798dd7ba5
commit edced6818a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 391 additions and 389 deletions

View file

@ -1,4 +1,4 @@
package fileacquisition package fileacquisition_test
import ( import (
"fmt" "fmt"
@ -7,32 +7,39 @@ import (
"testing" "testing"
"time" "time"
fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file"
"github.com/crowdsecurity/crowdsec/pkg/cstest" "github.com/crowdsecurity/crowdsec/pkg/cstest"
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test" "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/tomb.v2" "gopkg.in/tomb.v2"
) )
func TestBadConfiguration(t *testing.T) { func TestBadConfiguration(t *testing.T) {
tests := []struct { tests := []struct {
name string
config string config string
expectedErr string expectedErr string
}{ }{
{ {
config: `foobar: asd.log`, name: "extra configuration key",
config: "foobar: asd.log",
expectedErr: "line 1: field foobar not found in type fileacquisition.FileConfiguration", expectedErr: "line 1: field foobar not found in type fileacquisition.FileConfiguration",
}, },
{ {
config: `mode: tail`, name: "missing filenames",
config: "mode: tail",
expectedErr: "no filename or filenames configuration provided", expectedErr: "no filename or filenames configuration provided",
}, },
{ {
name: "glob syntax error",
config: `filename: "[asd-.log"`, config: `filename: "[asd-.log"`,
expectedErr: "Glob failure: syntax error in pattern", expectedErr: "Glob failure: syntax error in pattern",
}, },
{ {
name: "bad exclude regexp",
config: `filenames: ["asd.log"] config: `filenames: ["asd.log"]
exclude_regexps: ["as[a-$d"]`, exclude_regexps: ["as[a-$d"]`,
expectedErr: "Could not compile regexp as", expectedErr: "Could not compile regexp as",
@ -42,20 +49,24 @@ exclude_regexps: ["as[a-$d"]`,
subLogger := log.WithFields(log.Fields{ subLogger := log.WithFields(log.Fields{
"type": "file", "type": "file",
}) })
for _, test := range tests {
f := FileSource{} for _, tc := range tests {
err := f.Configure([]byte(test.config), subLogger) tc := tc
assert.Contains(t, err.Error(), test.expectedErr) t.Run(tc.name, func(t *testing.T) {
f := fileacquisition.FileSource{}
err := f.Configure([]byte(tc.config), subLogger)
cstest.RequireErrorContains(t, err, tc.expectedErr)
})
} }
} }
func TestConfigureDSN(t *testing.T) { func TestConfigureDSN(t *testing.T) {
var file string file := "/etc/passwd"
if runtime.GOOS != "windows" {
file = "/etc/passwd" if runtime.GOOS == "windows" {
} else { file = `C:\Windows\System32\drivers\etc\hosts`
file = "C:\\Windows\\System32\\drivers\\etc\\hosts"
} }
tests := []struct { tests := []struct {
dsn string dsn string
expectedErr string expectedErr string
@ -69,37 +80,41 @@ func TestConfigureDSN(t *testing.T) {
expectedErr: "empty file:// DSN", expectedErr: "empty file:// DSN",
}, },
{ {
dsn: fmt.Sprintf("file://%s?log_level=warn", file), dsn: fmt.Sprintf("file://%s?log_level=warn", file),
expectedErr: "",
}, },
{ {
dsn: fmt.Sprintf("file://%s?log_level=foobar", file), dsn: fmt.Sprintf("file://%s?log_level=foobar", file),
expectedErr: "unknown level foobar: not a valid logrus Level:", expectedErr: "unknown level foobar: not a valid logrus Level:",
}, },
} }
subLogger := log.WithFields(log.Fields{ subLogger := log.WithFields(log.Fields{
"type": "file", "type": "file",
}) })
for _, test := range tests {
f := FileSource{} for _, tc := range tests {
err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger) tc := tc
cstest.AssertErrorContains(t, err, test.expectedErr) t.Run(tc.dsn, func(t *testing.T) {
f := fileacquisition.FileSource{}
err := f.ConfigureByDSN(tc.dsn, map[string]string{"type": "testtype"}, subLogger)
cstest.RequireErrorContains(t, err, tc.expectedErr)
})
} }
} }
func TestOneShot(t *testing.T) { func TestOneShot(t *testing.T) {
var permDeniedFile string permDeniedFile := "/etc/shadow"
var permDeniedError string permDeniedError := "failed opening /etc/shadow: open /etc/shadow: permission denied"
if runtime.GOOS != "windows" {
permDeniedFile = "/etc/shadow" if runtime.GOOS == "windows" {
permDeniedError = "failed opening /etc/shadow: open /etc/shadow: permission denied" // Technically, this is not a permission denied error, but we just want to test what happens
} else { // if we do not have access to the file
//Technically, this is not a permission denied error, but we just want to test what happens permDeniedFile = `C:\Windows\System32\config\SAM`
//if we do not have access to the file permDeniedError = `failed opening C:\Windows\System32\config\SAM: open C:\Windows\System32\config\SAM: The process cannot access the file because it is being used by another process.`
permDeniedFile = "C:\\Windows\\System32\\config\\SAM"
permDeniedError = "failed opening C:\\Windows\\System32\\config\\SAM: open C:\\Windows\\System32\\config\\SAM: The process cannot access the file because it is being used by another process."
} }
tests := []struct { tests := []struct {
name string
config string config string
expectedConfigErr string expectedConfigErr string
expectedErr string expectedErr string
@ -111,76 +126,68 @@ func TestOneShot(t *testing.T) {
teardown func() teardown func()
}{ }{
{ {
name: "permission denied",
config: fmt.Sprintf(` config: fmt.Sprintf(`
mode: cat mode: cat
filename: %s`, permDeniedFile), filename: %s`, permDeniedFile),
expectedConfigErr: "", expectedErr: permDeniedError,
expectedErr: permDeniedError, logLevel: log.WarnLevel,
expectedOutput: "", expectedLines: 0,
logLevel: log.WarnLevel,
expectedLines: 0,
}, },
{ {
name: "ignored directory",
config: ` config: `
mode: cat mode: cat
filename: /`, filename: /`,
expectedConfigErr: "", expectedOutput: "/ is a directory, ignoring it",
expectedErr: "", logLevel: log.WarnLevel,
expectedOutput: "/ is a directory, ignoring it", expectedLines: 0,
logLevel: log.WarnLevel,
expectedLines: 0,
}, },
{ {
name: "glob syntax error",
config: ` config: `
mode: cat mode: cat
filename: "[*-.log"`, filename: "[*-.log"`,
expectedConfigErr: "Glob failure: syntax error in pattern", expectedConfigErr: "Glob failure: syntax error in pattern",
expectedErr: "",
expectedOutput: "",
logLevel: log.WarnLevel, logLevel: log.WarnLevel,
expectedLines: 0, expectedLines: 0,
}, },
{ {
name: "no matching files",
config: ` config: `
mode: cat mode: cat
filename: /do/not/exist`, filename: /do/not/exist`,
expectedConfigErr: "", expectedOutput: "No matching files for pattern /do/not/exist",
expectedErr: "", logLevel: log.WarnLevel,
expectedOutput: "No matching files for pattern /do/not/exist", expectedLines: 0,
logLevel: log.WarnLevel,
expectedLines: 0,
}, },
{ {
name: "test.log",
config: ` config: `
mode: cat mode: cat
filename: test_files/test.log`, filename: test_files/test.log`,
expectedConfigErr: "", expectedLines: 5,
expectedErr: "", logLevel: log.WarnLevel,
expectedOutput: "",
expectedLines: 5,
logLevel: log.WarnLevel,
}, },
{ {
name: "test.log.gz",
config: ` config: `
mode: cat mode: cat
filename: test_files/test.log.gz`, filename: test_files/test.log.gz`,
expectedConfigErr: "", expectedLines: 5,
expectedErr: "", logLevel: log.WarnLevel,
expectedOutput: "",
expectedLines: 5,
logLevel: log.WarnLevel,
}, },
{ {
name: "unexpected end of gzip stream",
config: ` config: `
mode: cat mode: cat
filename: test_files/bad.gz`, filename: test_files/bad.gz`,
expectedConfigErr: "", expectedErr: "failed to read gz test_files/bad.gz: unexpected EOF",
expectedErr: "failed to read gz test_files/bad.gz: unexpected EOF", expectedLines: 0,
expectedOutput: "", logLevel: log.WarnLevel,
expectedLines: 0,
logLevel: log.WarnLevel,
}, },
{ {
name: "deleted file",
config: ` config: `
mode: cat mode: cat
filename: test_files/test_delete.log`, filename: test_files/test_delete.log`,
@ -195,77 +202,84 @@ filename: test_files/test_delete.log`,
}, },
} }
for _, ts := range tests { for _, tc := range tests {
logger, hook := test.NewNullLogger() tc := tc
logger.SetLevel(ts.logLevel) t.Run(tc.name, func(t *testing.T) {
subLogger := logger.WithFields(log.Fields{ logger, hook := test.NewNullLogger()
"type": "file", logger.SetLevel(tc.logLevel)
})
tomb := tomb.Tomb{}
out := make(chan types.Event)
f := FileSource{}
if ts.setup != nil {
ts.setup()
}
err := f.Configure([]byte(ts.config), subLogger)
cstest.AssertErrorContains(t, err, ts.expectedConfigErr)
if err != nil {
continue
}
if ts.afterConfigure != nil { subLogger := logger.WithFields(log.Fields{
ts.afterConfigure() "type": "file",
} })
actualLines := 0
if ts.expectedLines != 0 { tomb := tomb.Tomb{}
go func() { out := make(chan types.Event)
READLOOP: f := fileacquisition.FileSource{}
for {
select { if tc.setup != nil {
case <-out: tc.setup()
actualLines++ }
case <-time.After(1 * time.Second):
break READLOOP err := f.Configure([]byte(tc.config), subLogger)
cstest.RequireErrorContains(t, err, tc.expectedConfigErr)
if tc.expectedConfigErr != "" {
return
}
if tc.afterConfigure != nil {
tc.afterConfigure()
}
actualLines := 0
if tc.expectedLines != 0 {
go func() {
for {
select {
case <-out:
actualLines++
case <-time.After(2 * time.Second):
return
}
} }
} }()
}() }
}
err = f.OneShotAcquisition(out, &tomb)
cstest.AssertErrorContains(t, err, ts.expectedErr)
if ts.expectedLines != 0 { err = f.OneShotAcquisition(out, &tomb)
assert.Equal(t, actualLines, ts.expectedLines) cstest.RequireErrorContains(t, err, tc.expectedErr)
}
if ts.expectedOutput != "" { if tc.expectedLines != 0 {
assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) assert.Equal(t, tc.expectedLines, actualLines)
hook.Reset() }
}
if ts.teardown != nil { if tc.expectedOutput != "" {
ts.teardown() assert.Contains(t, hook.LastEntry().Message, tc.expectedOutput)
} hook.Reset()
}
if tc.teardown != nil {
tc.teardown()
}
})
} }
} }
func TestLiveAcquisition(t *testing.T) { func TestLiveAcquisition(t *testing.T) {
var permDeniedFile string permDeniedFile := "/etc/shadow"
var permDeniedError string permDeniedError := "unable to read /etc/shadow : open /etc/shadow: permission denied"
var testPattern string testPattern := "test_files/*.log"
if runtime.GOOS != "windows" {
permDeniedFile = "/etc/shadow" if runtime.GOOS == "windows" {
permDeniedError = "unable to read /etc/shadow : open /etc/shadow: permission denied" // Technically, this is not a permission denied error, but we just want to test what happens
testPattern = "test_files/*.log" // if we do not have access to the file
} else { permDeniedFile = `C:\Windows\System32\config\SAM`
//Technically, this is not a permission denied error, but we just want to test what happens permDeniedError = `unable to read C:\Windows\System32\config\SAM : open C:\Windows\System32\config\SAM: The process cannot access the file because it is being used by another process`
//if we do not have access to the file testPattern = `test_files\\*.log` // the \ must be escaped for the yaml config
permDeniedFile = "C:\\Windows\\System32\\config\\SAM"
permDeniedError = "unable to read C:\\Windows\\System32\\config\\SAM : open C:\\Windows\\System32\\config\\SAM: The process cannot access the file because it is being used by another process"
testPattern = "test_files\\\\*.log" // the \ must be escaped twice: once for the string, once for the yaml config
} }
tests := []struct { tests := []struct {
name string
config string config string
expectedErr string expectedErr string
expectedOutput string expectedOutput string
name string
expectedLines int expectedLines int
logLevel log.Level logLevel log.Level
setup func() setup func()
@ -276,7 +290,6 @@ func TestLiveAcquisition(t *testing.T) {
config: fmt.Sprintf(` config: fmt.Sprintf(`
mode: tail mode: tail
filename: %s`, permDeniedFile), filename: %s`, permDeniedFile),
expectedErr: "",
expectedOutput: permDeniedError, expectedOutput: permDeniedError,
logLevel: log.InfoLevel, logLevel: log.InfoLevel,
expectedLines: 0, expectedLines: 0,
@ -286,7 +299,6 @@ filename: %s`, permDeniedFile),
config: ` config: `
mode: tail mode: tail
filename: /`, filename: /`,
expectedErr: "",
expectedOutput: "/ is a directory, ignoring it", expectedOutput: "/ is a directory, ignoring it",
logLevel: log.WarnLevel, logLevel: log.WarnLevel,
expectedLines: 0, expectedLines: 0,
@ -296,7 +308,6 @@ filename: /`,
config: ` config: `
mode: tail mode: tail
filename: /do/not/exist`, filename: /do/not/exist`,
expectedErr: "",
expectedOutput: "No matching files for pattern /do/not/exist", expectedOutput: "No matching files for pattern /do/not/exist",
logLevel: log.WarnLevel, logLevel: log.WarnLevel,
expectedLines: 0, expectedLines: 0,
@ -308,11 +319,9 @@ mode: tail
filenames: filenames:
- %s - %s
force_inotify: true`, testPattern), force_inotify: true`, testPattern),
expectedErr: "", expectedLines: 5,
expectedOutput: "", logLevel: log.DebugLevel,
expectedLines: 5, name: "basicGlob",
logLevel: log.DebugLevel,
name: "basicGlob",
}, },
{ {
config: fmt.Sprintf(` config: fmt.Sprintf(`
@ -320,11 +329,9 @@ mode: tail
filenames: filenames:
- %s - %s
force_inotify: true`, testPattern), force_inotify: true`, testPattern),
expectedErr: "", expectedLines: 0,
expectedOutput: "", logLevel: log.DebugLevel,
expectedLines: 0, name: "GlobInotify",
logLevel: log.DebugLevel,
name: "GlobInotify",
afterConfigure: func() { afterConfigure: func() {
f, _ := os.Create("test_files/a.log") f, _ := os.Create("test_files/a.log")
f.Close() f.Close()
@ -338,19 +345,17 @@ mode: tail
filenames: filenames:
- %s - %s
force_inotify: true`, testPattern), force_inotify: true`, testPattern),
expectedErr: "", expectedLines: 5,
expectedOutput: "", logLevel: log.DebugLevel,
expectedLines: 5, name: "GlobInotifyChmod",
logLevel: log.DebugLevel,
name: "GlobInotifyChmod",
afterConfigure: func() { afterConfigure: func() {
f, _ := os.Create("test_files/a.log") f, _ := os.Create("test_files/a.log")
f.Close() f.Close()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
os.Chmod("test_files/a.log", 0000) os.Chmod("test_files/a.log", 0o000)
}, },
teardown: func() { teardown: func() {
os.Chmod("test_files/a.log", 0644) os.Chmod("test_files/a.log", 0o644)
os.Remove("test_files/a.log") os.Remove("test_files/a.log")
}, },
}, },
@ -360,13 +365,11 @@ mode: tail
filenames: filenames:
- %s - %s
force_inotify: true`, testPattern), force_inotify: true`, testPattern),
expectedErr: "", expectedLines: 5,
expectedOutput: "", logLevel: log.DebugLevel,
expectedLines: 5, name: "InotifyMkDir",
logLevel: log.DebugLevel,
name: "InotifyMkDir",
afterConfigure: func() { afterConfigure: func() {
os.Mkdir("test_files/pouet/", 0700) os.Mkdir("test_files/pouet/", 0o700)
}, },
teardown: func() { teardown: func() {
os.Remove("test_files/pouet/") os.Remove("test_files/pouet/")
@ -374,101 +377,112 @@ force_inotify: true`, testPattern),
}, },
} }
for _, ts := range tests { for _, tc := range tests {
t.Logf("test: %s", ts.name) tc := tc
logger, hook := test.NewNullLogger() t.Run(tc.name, func(t *testing.T) {
logger.SetLevel(ts.logLevel) logger, hook := test.NewNullLogger()
subLogger := logger.WithFields(log.Fields{ logger.SetLevel(tc.logLevel)
"type": "file",
}) subLogger := logger.WithFields(log.Fields{
tomb := tomb.Tomb{} "type": "file",
out := make(chan types.Event) })
f := FileSource{}
if ts.setup != nil { tomb := tomb.Tomb{}
ts.setup() out := make(chan types.Event)
}
err := f.Configure([]byte(ts.config), subLogger) f := fileacquisition.FileSource{}
if err != nil {
t.Fatalf("Unexpected error : %s", err) if tc.setup != nil {
} tc.setup()
if ts.afterConfigure != nil { }
ts.afterConfigure()
} err := f.Configure([]byte(tc.config), subLogger)
actualLines := 0 require.NoError(t, err)
if ts.expectedLines != 0 {
go func() { if tc.afterConfigure != nil {
READLOOP: tc.afterConfigure()
for { }
select {
case <-out: actualLines := 0
actualLines++ if tc.expectedLines != 0 {
case <-time.After(2 * time.Second): go func() {
break READLOOP for {
select {
case <-out:
actualLines++
case <-time.After(2 * time.Second):
return
}
}
}()
}
err = f.StreamingAcquisition(out, &tomb)
cstest.RequireErrorContains(t, err, tc.expectedErr)
if tc.expectedLines != 0 {
fd, err := os.Create("test_files/stream.log")
if err != nil {
t.Fatalf("could not create test file : %s", err)
}
for i := 0; i < 5; i++ {
_, err = fmt.Fprintf(fd, "%d\n", i)
if err != nil {
t.Fatalf("could not write test file : %s", err)
os.Remove("test_files/stream.log")
} }
} }
}()
}
err = f.StreamingAcquisition(out, &tomb)
cstest.AssertErrorContains(t, err, ts.expectedErr)
if ts.expectedLines != 0 { fd.Close()
fd, err := os.Create("test_files/stream.log") // we sleep to make sure we detect the new file
if err != nil { time.Sleep(1 * time.Second)
t.Fatalf("could not create test file : %s", err) os.Remove("test_files/stream.log")
assert.Equal(t, tc.expectedLines, actualLines)
} }
for i := 0; i < 5; i++ {
_, err = fd.WriteString(fmt.Sprintf("%d\n", i)) if tc.expectedOutput != "" {
if err != nil { if hook.LastEntry() == nil {
t.Fatalf("could not write test file : %s", err) t.Fatalf("expected output %s, but got nothing", tc.expectedOutput)
os.Remove("test_files/stream.log")
} }
assert.Contains(t, hook.LastEntry().Message, tc.expectedOutput)
hook.Reset()
} }
fd.Close()
//we sleep to make sure we detect the new file
time.Sleep(1 * time.Second)
os.Remove("test_files/stream.log")
assert.Equal(t, ts.expectedLines, actualLines)
}
if ts.expectedOutput != "" { if tc.teardown != nil {
if hook.LastEntry() == nil { tc.teardown()
t.Fatalf("expected output %s, but got nothing", ts.expectedOutput)
} }
assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput)
hook.Reset()
}
if ts.teardown != nil { tomb.Kill(nil)
ts.teardown() })
}
tomb.Kill(nil)
} }
} }
func TestExclusion(t *testing.T) { func TestExclusion(t *testing.T) {
config := `filenames: ["test_files/*.log*"] config := `filenames: ["test_files/*.log*"]
exclude_regexps: ["\\.gz$"]` exclude_regexps: ["\\.gz$"]`
logger, hook := test.NewNullLogger() logger, hook := test.NewNullLogger()
//logger.SetLevel(ts.logLevel) // logger.SetLevel(ts.logLevel)
subLogger := logger.WithFields(log.Fields{ subLogger := logger.WithFields(log.Fields{
"type": "file", "type": "file",
}) })
f := FileSource{}
err := f.Configure([]byte(config), subLogger) f := fileacquisition.FileSource{}
if err != nil { if err := f.Configure([]byte(config), subLogger); err != nil {
subLogger.Fatalf("unexpected error: %s", err) subLogger.Fatalf("unexpected error: %s", err)
} }
var expectedLogOutput string
expectedLogOutput := "Skipping file test_files/test.log.gz as it matches exclude pattern"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
expectedLogOutput = "Skipping file test_files\\test.log.gz as it matches exclude pattern \\.gz" expectedLogOutput = `Skipping file test_files\test.log.gz as it matches exclude pattern \.gz`
} else {
expectedLogOutput = "Skipping file test_files/test.log.gz as it matches exclude pattern"
} }
if hook.LastEntry() == nil { if hook.LastEntry() == nil {
t.Fatalf("expected output %s, but got nothing", expectedLogOutput) t.Fatalf("expected output %s, but got nothing", expectedLogOutput)
} }
assert.Contains(t, hook.LastEntry().Message, expectedLogOutput) assert.Contains(t, hook.LastEntry().Message, expectedLogOutput)
hook.Reset() hook.Reset()
} }

View file

@ -400,7 +400,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
return alerts return alerts
} }
//we receive only one list of decisions, that we need to break-up : // we receive only one list of decisions, that we need to break-up :
// one alert for "community blocklist" // one alert for "community blocklist"
// one alert per list we're subscribed to // one alert per list we're subscribed to
func (a *apic) PullTop() error { func (a *apic) PullTop() error {
@ -432,7 +432,7 @@ func (a *apic) PullTop() error {
return nil return nil
} }
//we receive only one list of decisions, that we need to break-up : // we receive only one list of decisions, that we need to break-up :
// one alert for "community blocklist" // one alert for "community blocklist"
// one alert per list we're subscribed to // one alert per list we're subscribed to
alertsFromCapi := createAlertsForDecisions(data.New) alertsFromCapi := createAlertsForDecisions(data.New)
@ -541,37 +541,32 @@ func (a *apic) GetMetrics() (*models.Metrics, error) {
return metric, nil return metric, nil
} }
func (a *apic) SendMetrics() error { func (a *apic) SendMetrics(stop chan (bool)) {
defer types.CatchPanic("lapi/metricsToAPIC") defer types.CatchPanic("lapi/metricsToAPIC")
metrics, err := a.GetMetrics() log.Infof("Start send metrics to CrowdSec Central API (interval: %s)", a.metricsInterval)
if err != nil {
log.Errorf("unable to get metrics (%s), will retry", err)
}
_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
if err != nil {
log.Errorf("unable to send metrics (%s), will retry", err)
}
log.Infof("capi metrics: metrics sent successfully")
log.Infof("Start send metrics to CrowdSec Central API (interval: %s)", MetricsInterval)
ticker := time.NewTicker(a.metricsInterval) ticker := time.NewTicker(a.metricsInterval)
for { for {
metrics, err := a.GetMetrics()
if err != nil {
log.Errorf("unable to get metrics (%s), will retry", err)
}
_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
if err != nil {
log.Errorf("capi metrics: failed: %s", err)
} else {
log.Infof("capi metrics: metrics sent successfully")
}
select { select {
case <-stop:
return
case <-ticker.C: case <-ticker.C:
metrics, err := a.GetMetrics() continue
if err != nil {
log.Errorf("unable to get metrics (%s), will retry", err)
}
_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
if err != nil {
log.Errorf("capi metrics: failed: %s", err)
} else {
log.Infof("capi metrics: metrics sent successfully")
}
case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others? case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others?
a.pullTomb.Kill(nil) a.pullTomb.Kill(nil)
a.pushTomb.Kill(nil) a.pushTomb.Kill(nil)
return nil return
} }
} }
} }

View file

@ -8,13 +8,13 @@ import (
"net/url" "net/url"
"os" "os"
"reflect" "reflect"
"sort"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/cstest"
"github.com/crowdsecurity/crowdsec/pkg/cwversion" "github.com/crowdsecurity/crowdsec/pkg/cwversion"
"github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
@ -24,23 +24,20 @@ import (
"github.com/jarcoal/httpmock" "github.com/jarcoal/httpmock"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/tomb.v2" "gopkg.in/tomb.v2"
) )
func getDBClient(t *testing.T) *database.Client { func getDBClient(t *testing.T) *database.Client {
t.Helper() t.Helper()
dbPath, err := os.CreateTemp("", "*sqlite") dbPath, err := os.CreateTemp("", "*sqlite")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
dbClient, err := database.NewClient(&csconfig.DatabaseCfg{ dbClient, err := database.NewClient(&csconfig.DatabaseCfg{
Type: "sqlite", Type: "sqlite",
DbName: "crowdsec", DbName: "crowdsec",
DbPath: dbPath.Name(), DbPath: dbPath.Name(),
}) })
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
return dbClient return dbClient
} }
@ -98,11 +95,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) {
func TestAPICCAPIPullIsOld(t *testing.T) { func TestAPICCAPIPullIsOld(t *testing.T) {
api := getAPIC(t) api := getAPIC(t)
isOld, err := api.CAPIPullIsOld()
if err != nil {
t.Fatal(err)
}
isOld, err := api.CAPIPullIsOld()
require.NoError(t, err)
assert.True(t, isOld) assert.True(t, isOld)
decision := api.dbClient.Ent.Decision.Create(). decision := api.dbClient.Ent.Decision.Create().
@ -123,16 +118,13 @@ func TestAPICCAPIPullIsOld(t *testing.T) {
SaveX(context.Background()) SaveX(context.Background())
isOld, err = api.CAPIPullIsOld() isOld, err = api.CAPIPullIsOld()
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
assert.False(t, isOld) assert.False(t, isOld)
} }
func TestAPICFetchScenariosListFromDB(t *testing.T) { func TestAPICFetchScenariosListFromDB(t *testing.T) {
api := getAPIC(t) tests := []struct {
testCases := []struct {
name string name string
machineIDsWithScenarios map[string]string machineIDsWithScenarios map[string]string
expectedScenarios []string expectedScenarios []string
@ -154,8 +146,10 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
}, },
} }
for _, tc := range testCases { for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
api := getAPIC(t)
for machineID, scenarios := range tc.machineIDsWithScenarios { for machineID, scenarios := range tc.machineIDsWithScenarios {
api.dbClient.Ent.Machine.Create(). api.dbClient.Ent.Machine.Create().
SetMachineId(machineID). SetMachineId(machineID).
@ -164,17 +158,14 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
SetScenarios(scenarios). SetScenarios(scenarios).
ExecX(context.Background()) ExecX(context.Background())
} }
scenarios, err := api.FetchScenariosListFromDB() scenarios, err := api.FetchScenariosListFromDB()
for machineID := range tc.machineIDsWithScenarios { for machineID := range tc.machineIDsWithScenarios {
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background()) api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
} }
if err != nil { require.NoError(t, err)
t.Fatal(err)
} else { assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
sort.Strings(scenarios)
sort.Strings(tc.expectedScenarios)
assert.Equal(t, scenarios, tc.expectedScenarios)
}
}) })
} }
@ -196,11 +187,10 @@ func TestNewAPIC(t *testing.T) {
consoleConfig *csconfig.ConsoleConfig consoleConfig *csconfig.ConsoleConfig
} }
tests := []struct { tests := []struct {
name string name string
args args args args
wantErr bool expectedErr string
errorContains string action func()
action func()
}{ }{
{ {
name: "simple", name: "simple",
@ -217,20 +207,16 @@ func TestNewAPIC(t *testing.T) {
dbClient: getDBClient(t), dbClient: getDBClient(t),
consoleConfig: LoadTestConfig().API.Server.ConsoleConfig, consoleConfig: LoadTestConfig().API.Server.ConsoleConfig,
}, },
wantErr: true, expectedErr: "first path segment in URL cannot contain colon",
errorContains: "first path segment in URL cannot contain colon",
}, },
} }
for _, tt := range tests { for _, tc := range tests {
t.Run(tt.name, func(t *testing.T) { tc := tc
t.Run(tc.name, func(t *testing.T) {
setConfig() setConfig()
tt.action() tc.action()
_, err := NewAPIC(testConfig, tt.args.dbClient, tt.args.consoleConfig) _, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig)
if tt.wantErr { cstest.RequireErrorContains(t, err, tc.expectedErr)
assert.ErrorContains(t, err, tt.errorContains)
} else {
assert.NoError(t, err)
}
}) })
} }
} }
@ -268,17 +254,16 @@ func TestAPICHandleDeletedDecisions(t *testing.T) {
}}, deleteCounters) }}, deleteCounters)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, nbDeleted, 2) assert.Equal(t, 2, nbDeleted)
assert.Equal(t, deleteCounters[SCOPE_CAPI]["all"], 2) assert.Equal(t, 2, deleteCounters[SCOPE_CAPI]["all"])
} }
func TestAPICGetMetrics(t *testing.T) { func TestAPICGetMetrics(t *testing.T) {
api := getAPIC(t) cleanUp := func(api *apic) {
cleanUp := func() {
api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background()) api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background())
api.dbClient.Ent.Machine.Delete().ExecX(context.Background()) api.dbClient.Ent.Machine.Delete().ExecX(context.Background())
} }
testCases := []struct { tests := []struct {
name string name string
machineIDs []string machineIDs []string
bouncers []string bouncers []string
@ -322,11 +307,13 @@ func TestAPICGetMetrics(t *testing.T) {
}, },
}, },
} }
for _, testCase := range testCases { for _, tc := range tests {
t.Run(testCase.name, func(t *testing.T) { tc := tc
cleanUp() t.Run(tc.name, func(t *testing.T) {
for i, machineID := range testCase.machineIDs { apiClient := getAPIC(t)
api.dbClient.Ent.Machine.Create(). cleanUp(apiClient)
for i, machineID := range tc.machineIDs {
apiClient.dbClient.Ent.Machine.Create().
SetMachineId(machineID). SetMachineId(machineID).
SetPassword(testPassword.String()). SetPassword(testPassword.String()).
SetIpAddress(fmt.Sprintf("1.2.3.%d", i)). SetIpAddress(fmt.Sprintf("1.2.3.%d", i)).
@ -336,8 +323,8 @@ func TestAPICGetMetrics(t *testing.T) {
ExecX(context.Background()) ExecX(context.Background())
} }
for i, bouncerName := range testCase.bouncers { for i, bouncerName := range tc.bouncers {
api.dbClient.Ent.Bouncer.Create(). apiClient.dbClient.Ent.Bouncer.Create().
SetIPAddress(fmt.Sprintf("1.2.3.%d", i)). SetIPAddress(fmt.Sprintf("1.2.3.%d", i)).
SetName(bouncerName). SetName(bouncerName).
SetAPIKey("foobar"). SetAPIKey("foobar").
@ -346,19 +333,17 @@ func TestAPICGetMetrics(t *testing.T) {
ExecX(context.Background()) ExecX(context.Background())
} }
if foundMetrics, err := api.GetMetrics(); err != nil { foundMetrics, err := apiClient.GetMetrics()
t.Fatal(err) require.NoError(t, err)
} else {
assert.Equal(t, foundMetrics.Bouncers, testCase.expectedMetric.Bouncers) assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers)
assert.Equal(t, foundMetrics.Machines, testCase.expectedMetric.Machines) assert.Equal(t, tc.expectedMetric.Machines, foundMetrics.Machines)
}
}) })
} }
} }
func TestCreateAlertsForDecision(t *testing.T) { func TestCreateAlertsForDecision(t *testing.T) {
httpBfDecisionList := &models.Decision{ httpBfDecisionList := &models.Decision{
Origin: &SCOPE_LISTS, Origin: &SCOPE_LISTS,
Scenario: types.StrPtr("crowdsecurity/http-bf"), Scenario: types.StrPtr("crowdsecurity/http-bf"),
@ -427,10 +412,11 @@ func TestCreateAlertsForDecision(t *testing.T) {
}, },
}, },
} }
for _, tt := range tests { for _, tc := range tests {
t.Run(tt.name, func(t *testing.T) { tc := tc
if got := createAlertsForDecisions(tt.args.decisions); !reflect.DeepEqual(got, tt.want) { t.Run(tc.name, func(t *testing.T) {
t.Errorf("createAlertsForDecisions() = %v, want %v", got, tt.want) if got := createAlertsForDecisions(tc.args.decisions); !reflect.DeepEqual(got, tc.want) {
t.Errorf("createAlertsForDecisions() = %v, want %v", got, tc.want)
} }
}) })
} }
@ -503,11 +489,12 @@ func TestFillAlertsWithDecisions(t *testing.T) {
}, },
}, },
} }
for _, tt := range tests { for _, tc := range tests {
t.Run(tt.name, func(t *testing.T) { tc := tc
add_counters, _ := makeAddAndDeleteCounters() t.Run(tc.name, func(t *testing.T) {
if got := fillAlertsWithDecisions(tt.args.alerts, tt.args.decisions, add_counters); !reflect.DeepEqual(got, tt.want) { addCounters, _ := makeAddAndDeleteCounters()
t.Errorf("fillAlertsWithDecisions() = %v, want %v", got, tt.want) if got := fillAlertsWithDecisions(tc.args.alerts, tc.args.decisions, addCounters); !reflect.DeepEqual(got, tc.want) {
t.Errorf("fillAlertsWithDecisions() = %v, want %v", got, tc.want)
} }
}) })
} }
@ -586,24 +573,19 @@ func TestAPICPullTop(t *testing.T) {
), ),
)) ))
url, err := url.ParseRequestURI("http://api.crowdsec.net/") url, err := url.ParseRequestURI("http://api.crowdsec.net/")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
apic, err := apiclient.NewDefaultClient( apic, err := apiclient.NewDefaultClient(
url, url,
"/api", "/api",
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
nil, nil,
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
api.apiClient = apic api.apiClient = apic
err = api.PullTop() err = api.PullTop()
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
assertTotalDecisionCount(t, api.dbClient, 5) assertTotalDecisionCount(t, api.dbClient, 5)
assertTotalValidDecisionCount(t, api.dbClient, 4) assertTotalValidDecisionCount(t, api.dbClient, 4)
@ -619,24 +601,23 @@ func TestAPICPullTop(t *testing.T) {
for _, alert := range alerts { for _, alert := range alerts {
alertScenario[alert.SourceScope]++ alertScenario[alert.SourceScope]++
} }
assert.Equal(t, len(alertScenario), 3) assert.Equal(t, 3, len(alertScenario))
assert.Equal(t, alertScenario[SCOPE_CAPI_ALIAS], 1) assert.Equal(t, 1, alertScenario[SCOPE_CAPI_ALIAS])
assert.Equal(t, alertScenario["lists:crowdsecurity/ssh-bf"], 1) assert.Equal(t, 1, alertScenario["lists:crowdsecurity/ssh-bf"])
assert.Equal(t, alertScenario["lists:crowdsecurity/http-bf"], 1) assert.Equal(t, 1, alertScenario["lists:crowdsecurity/http-bf"])
for _, decisions := range validDecisions { for _, decisions := range validDecisions {
decisionScenarioFreq[decisions.Scenario]++ decisionScenarioFreq[decisions.Scenario]++
} }
assert.Equal(t, decisionScenarioFreq["crowdsecurity/http-bf"], 1) assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/http-bf"], 1)
assert.Equal(t, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1) assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1)
assert.Equal(t, decisionScenarioFreq["crowdsecurity/test1"], 1) assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test1"], 1)
assert.Equal(t, decisionScenarioFreq["crowdsecurity/test2"], 1) assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test2"], 1)
} }
func TestAPICPush(t *testing.T) { func TestAPICPush(t *testing.T) {
tests := []struct {
testCases := []struct {
name string name string
alerts []*models.Alert alerts []*models.Alert
expectedCalls int expectedCalls int
@ -683,14 +664,14 @@ func TestAPICPush(t *testing.T) {
}, },
} }
for _, testCase := range testCases { for _, tc := range tests {
t.Run(testCase.name, func(t *testing.T) { tc := tc
t.Run(tc.name, func(t *testing.T) {
api := getAPIC(t) api := getAPIC(t)
api.pushInterval = time.Millisecond api.pushInterval = time.Millisecond
url, err := url.ParseRequestURI("http://api.crowdsec.net/") url, err := url.ParseRequestURI("http://api.crowdsec.net/")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
httpmock.Activate() httpmock.Activate()
defer httpmock.DeactivateAndReset() defer httpmock.DeactivateAndReset()
apic, err := apiclient.NewDefaultClient( apic, err := apiclient.NewDefaultClient(
@ -699,31 +680,28 @@ func TestAPICPush(t *testing.T) {
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
nil, nil,
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
api.apiClient = apic api.apiClient = apic
httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/signals", httpmock.NewBytesResponder(200, []byte{})) httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/signals", httpmock.NewBytesResponder(200, []byte{}))
go func() { go func() {
api.alertToPush <- testCase.alerts api.alertToPush <- tc.alerts
time.Sleep(time.Second) time.Sleep(time.Second)
api.Shutdown() api.Shutdown()
}() }()
if err := api.Push(); err != nil { err = api.Push()
t.Fatal(err) require.NoError(t, err)
} assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount())
assert.Equal(t, httpmock.GetTotalCallCount(), testCase.expectedCalls)
}) })
} }
} }
func TestAPICSendMetrics(t *testing.T) { func TestAPICSendMetrics(t *testing.T) {
api := getAPIC(t) tests := []struct {
testCases := []struct {
name string name string
duration time.Duration duration time.Duration
expectedCalls int expectedCalls int
setUp func() setUp func(*apic)
metricsInterval time.Duration metricsInterval time.Duration
}{ }{
{ {
@ -731,14 +709,15 @@ func TestAPICSendMetrics(t *testing.T) {
duration: time.Millisecond * 30, duration: time.Millisecond * 30,
metricsInterval: time.Millisecond * 5, metricsInterval: time.Millisecond * 5,
expectedCalls: 5, expectedCalls: 5,
setUp: func() {}, setUp: func(api *apic) {},
}, },
{ {
name: "with some metrics", name: "with some metrics",
duration: time.Millisecond * 30, duration: time.Millisecond * 30,
metricsInterval: time.Millisecond * 5, metricsInterval: time.Millisecond * 5,
expectedCalls: 5, expectedCalls: 5,
setUp: func() { setUp: func(api *apic) {
api.dbClient.Ent.Machine.Delete().ExecX(context.Background())
api.dbClient.Ent.Machine.Create(). api.dbClient.Ent.Machine.Create().
SetMachineId("1234"). SetMachineId("1234").
SetPassword(testPassword.String()). SetPassword(testPassword.String()).
@ -748,6 +727,7 @@ func TestAPICSendMetrics(t *testing.T) {
SetUpdatedAt(time.Time{}). SetUpdatedAt(time.Time{}).
ExecX(context.Background()) ExecX(context.Background())
api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background())
api.dbClient.Ent.Bouncer.Create(). api.dbClient.Ent.Bouncer.Create().
SetIPAddress("1.2.3.6"). SetIPAddress("1.2.3.6").
SetName("someBouncer"). SetName("someBouncer").
@ -758,44 +738,49 @@ func TestAPICSendMetrics(t *testing.T) {
}, },
}, },
} }
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{}))
api = getAPIC(t) httpmock.Activate()
api.pushInterval = time.Millisecond defer httpmock.Deactivate()
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
url, err := url.ParseRequestURI("http://api.crowdsec.net/") url, err := url.ParseRequestURI("http://api.crowdsec.net/")
if err != nil { require.NoError(t, err)
t.Fatal(err)
} apiClient, err := apiclient.NewDefaultClient(
httpmock.Activate()
defer httpmock.DeactivateAndReset()
apic, err := apiclient.NewDefaultClient(
url, url,
"/api", "/api",
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
nil, nil,
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
api.apiClient = apic
api.metricsInterval = testCase.metricsInterval
httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, []byte{}))
testCase.setUp()
go func() { api := getAPIC(t)
if err := api.SendMetrics(); err != nil { api.pushInterval = time.Millisecond
panic(err) api.apiClient = apiClient
} api.metricsInterval = tc.metricsInterval
}() tc.setUp(api)
time.Sleep(testCase.duration)
assert.LessOrEqual(t, absDiff(testCase.expectedCalls, httpmock.GetTotalCallCount()), 2) stop := make(chan bool)
httpmock.ZeroCallCounters()
go api.SendMetrics(stop)
time.Sleep(tc.duration)
stop <- true
info := httpmock.GetCallCountInfo()
noResponderCalls := info["NO_RESPONDER"]
responderCalls := info["POST http://api.crowdsec.net/api/metrics/"]
assert.LessOrEqual(t, absDiff(tc.expectedCalls, responderCalls), 2)
assert.Zero(t, noResponderCalls)
}) })
} }
} }
func TestAPICPull(t *testing.T) { func TestAPICPull(t *testing.T) {
api := getAPIC(t) api := getAPIC(t)
testCases := []struct { tests := []struct {
name string name string
setUp func() setUp func()
expectedDecisionCount int expectedDecisionCount int
@ -820,14 +805,13 @@ func TestAPICPull(t *testing.T) {
}, },
} }
for _, testCase := range testCases { for _, tc := range tests {
t.Run(testCase.name, func(t *testing.T) { tc := tc
t.Run(tc.name, func(t *testing.T) {
api = getAPIC(t) api = getAPIC(t)
api.pullInterval = time.Millisecond api.pullInterval = time.Millisecond
url, err := url.ParseRequestURI("http://api.crowdsec.net/") url, err := url.ParseRequestURI("http://api.crowdsec.net/")
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
httpmock.Activate() httpmock.Activate()
defer httpmock.DeactivateAndReset() defer httpmock.DeactivateAndReset()
apic, err := apiclient.NewDefaultClient( apic, err := apiclient.NewDefaultClient(
@ -836,9 +820,7 @@ func TestAPICPull(t *testing.T) {
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()), fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
nil, nil,
) )
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
api.apiClient = apic api.apiClient = apic
httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, jsonMarshalX( httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, jsonMarshalX(
models.DecisionsStreamResponse{ models.DecisionsStreamResponse{
@ -854,7 +836,7 @@ func TestAPICPull(t *testing.T) {
}, },
}, },
))) )))
testCase.setUp() tc.setUp()
var buf bytes.Buffer var buf bytes.Buffer
go func() { go func() {
logrus.SetOutput(&buf) logrus.SetOutput(&buf)
@ -865,15 +847,14 @@ func TestAPICPull(t *testing.T) {
//Slightly long because the CI runner for windows are slow, and this can lead to random failure //Slightly long because the CI runner for windows are slow, and this can lead to random failure
time.Sleep(time.Millisecond * 500) time.Sleep(time.Millisecond * 500)
logrus.SetOutput(os.Stderr) logrus.SetOutput(os.Stderr)
assert.Contains(t, buf.String(), testCase.logContains) assert.Contains(t, buf.String(), tc.logContains)
assertTotalDecisionCount(t, api.dbClient, testCase.expectedDecisionCount) assertTotalDecisionCount(t, api.dbClient, tc.expectedDecisionCount)
}) })
} }
} }
func TestShouldShareAlert(t *testing.T) { func TestShouldShareAlert(t *testing.T) {
tests := []struct {
testCases := []struct {
name string name string
consoleConfig *csconfig.ConsoleConfig consoleConfig *csconfig.ConsoleConfig
alert *models.Alert alert *models.Alert
@ -948,10 +929,11 @@ func TestShouldShareAlert(t *testing.T) {
}, },
} }
for _, testCase := range testCases { for _, tc := range tests {
t.Run(testCase.name, func(t *testing.T) { tc := tc
ret := shouldShareAlert(testCase.alert, testCase.consoleConfig) t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, ret, testCase.expectedRet) ret := shouldShareAlert(tc.alert, tc.consoleConfig)
assert.Equal(t, tc.expectedRet, ret)
}) })
} }
} }

View file

@ -311,10 +311,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
return nil return nil
}) })
s.apic.metricsTomb.Go(func() error { s.apic.metricsTomb.Go(func() error {
if err := s.apic.SendMetrics(); err != nil { s.apic.SendMetrics(make(chan bool))
log.Errorf("capi metrics: %s", err)
return err
}
return nil return nil
}) })
} }

View file

@ -7,6 +7,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Copy(sourceFile string, destinationFile string) error { func Copy(sourceFile string, destinationFile string) error {
@ -110,6 +111,8 @@ func CopyDir(src string, dest string) error {
} }
func AssertErrorContains(t *testing.T, err error, expectedErr string) { func AssertErrorContains(t *testing.T, err error, expectedErr string) {
t.Helper()
if expectedErr != "" { if expectedErr != "" {
assert.ErrorContains(t, err, expectedErr) assert.ErrorContains(t, err, expectedErr)
return return
@ -117,3 +120,14 @@ func AssertErrorContains(t *testing.T, err error, expectedErr string) {
assert.NoError(t, err) assert.NoError(t, err)
} }
func RequireErrorContains(t *testing.T, err error, expectedErr string) {
t.Helper()
if expectedErr != "" {
require.ErrorContains(t, err, expectedErr)
return
}
require.NoError(t, err)
}