cleanup + fix flaky tests in file_test.go, apic_test.go (#1773)

This commit is contained in:
mmetc 2022-09-30 16:01:42 +02:00 committed by GitHub
parent 6798dd7ba5
commit edced6818a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 391 additions and 389 deletions

View file

@ -1,4 +1,4 @@
package fileacquisition
package fileacquisition_test
import (
"fmt"
@ -7,32 +7,39 @@ import (
"testing"
"time"
fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file"
"github.com/crowdsecurity/crowdsec/pkg/cstest"
"github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/tomb.v2"
)
func TestBadConfiguration(t *testing.T) {
tests := []struct {
name string
config string
expectedErr string
}{
{
config: `foobar: asd.log`,
name: "extra configuration key",
config: "foobar: asd.log",
expectedErr: "line 1: field foobar not found in type fileacquisition.FileConfiguration",
},
{
config: `mode: tail`,
name: "missing filenames",
config: "mode: tail",
expectedErr: "no filename or filenames configuration provided",
},
{
name: "glob syntax error",
config: `filename: "[asd-.log"`,
expectedErr: "Glob failure: syntax error in pattern",
},
{
name: "bad exclude regexp",
config: `filenames: ["asd.log"]
exclude_regexps: ["as[a-$d"]`,
expectedErr: "Could not compile regexp as",
@ -42,20 +49,24 @@ exclude_regexps: ["as[a-$d"]`,
subLogger := log.WithFields(log.Fields{
"type": "file",
})
for _, test := range tests {
f := FileSource{}
err := f.Configure([]byte(test.config), subLogger)
assert.Contains(t, err.Error(), test.expectedErr)
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
f := fileacquisition.FileSource{}
err := f.Configure([]byte(tc.config), subLogger)
cstest.RequireErrorContains(t, err, tc.expectedErr)
})
}
}
func TestConfigureDSN(t *testing.T) {
var file string
if runtime.GOOS != "windows" {
file = "/etc/passwd"
} else {
file = "C:\\Windows\\System32\\drivers\\etc\\hosts"
file := "/etc/passwd"
if runtime.GOOS == "windows" {
file = `C:\Windows\System32\drivers\etc\hosts`
}
tests := []struct {
dsn string
expectedErr string
@ -69,37 +80,41 @@ func TestConfigureDSN(t *testing.T) {
expectedErr: "empty file:// DSN",
},
{
dsn: fmt.Sprintf("file://%s?log_level=warn", file),
expectedErr: "",
dsn: fmt.Sprintf("file://%s?log_level=warn", file),
},
{
dsn: fmt.Sprintf("file://%s?log_level=foobar", file),
expectedErr: "unknown level foobar: not a valid logrus Level:",
},
}
subLogger := log.WithFields(log.Fields{
"type": "file",
})
for _, test := range tests {
f := FileSource{}
err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger)
cstest.AssertErrorContains(t, err, test.expectedErr)
for _, tc := range tests {
tc := tc
t.Run(tc.dsn, func(t *testing.T) {
f := fileacquisition.FileSource{}
err := f.ConfigureByDSN(tc.dsn, map[string]string{"type": "testtype"}, subLogger)
cstest.RequireErrorContains(t, err, tc.expectedErr)
})
}
}
func TestOneShot(t *testing.T) {
var permDeniedFile string
var permDeniedError string
if runtime.GOOS != "windows" {
permDeniedFile = "/etc/shadow"
permDeniedError = "failed opening /etc/shadow: open /etc/shadow: permission denied"
} else {
//Technically, this is not a permission denied error, but we just want to test what happens
//if we do not have access to the file
permDeniedFile = "C:\\Windows\\System32\\config\\SAM"
permDeniedError = "failed opening C:\\Windows\\System32\\config\\SAM: open C:\\Windows\\System32\\config\\SAM: The process cannot access the file because it is being used by another process."
permDeniedFile := "/etc/shadow"
permDeniedError := "failed opening /etc/shadow: open /etc/shadow: permission denied"
if runtime.GOOS == "windows" {
// Technically, this is not a permission denied error, but we just want to test what happens
// if we do not have access to the file
permDeniedFile = `C:\Windows\System32\config\SAM`
permDeniedError = `failed opening C:\Windows\System32\config\SAM: open C:\Windows\System32\config\SAM: The process cannot access the file because it is being used by another process.`
}
tests := []struct {
name string
config string
expectedConfigErr string
expectedErr string
@ -111,76 +126,68 @@ func TestOneShot(t *testing.T) {
teardown func()
}{
{
name: "permission denied",
config: fmt.Sprintf(`
mode: cat
filename: %s`, permDeniedFile),
expectedConfigErr: "",
expectedErr: permDeniedError,
expectedOutput: "",
logLevel: log.WarnLevel,
expectedLines: 0,
expectedErr: permDeniedError,
logLevel: log.WarnLevel,
expectedLines: 0,
},
{
name: "ignored directory",
config: `
mode: cat
filename: /`,
expectedConfigErr: "",
expectedErr: "",
expectedOutput: "/ is a directory, ignoring it",
logLevel: log.WarnLevel,
expectedLines: 0,
expectedOutput: "/ is a directory, ignoring it",
logLevel: log.WarnLevel,
expectedLines: 0,
},
{
name: "glob syntax error",
config: `
mode: cat
filename: "[*-.log"`,
expectedConfigErr: "Glob failure: syntax error in pattern",
expectedErr: "",
expectedOutput: "",
logLevel: log.WarnLevel,
expectedLines: 0,
},
{
name: "no matching files",
config: `
mode: cat
filename: /do/not/exist`,
expectedConfigErr: "",
expectedErr: "",
expectedOutput: "No matching files for pattern /do/not/exist",
logLevel: log.WarnLevel,
expectedLines: 0,
expectedOutput: "No matching files for pattern /do/not/exist",
logLevel: log.WarnLevel,
expectedLines: 0,
},
{
name: "test.log",
config: `
mode: cat
filename: test_files/test.log`,
expectedConfigErr: "",
expectedErr: "",
expectedOutput: "",
expectedLines: 5,
logLevel: log.WarnLevel,
expectedLines: 5,
logLevel: log.WarnLevel,
},
{
name: "test.log.gz",
config: `
mode: cat
filename: test_files/test.log.gz`,
expectedConfigErr: "",
expectedErr: "",
expectedOutput: "",
expectedLines: 5,
logLevel: log.WarnLevel,
expectedLines: 5,
logLevel: log.WarnLevel,
},
{
name: "unexpected end of gzip stream",
config: `
mode: cat
filename: test_files/bad.gz`,
expectedConfigErr: "",
expectedErr: "failed to read gz test_files/bad.gz: unexpected EOF",
expectedOutput: "",
expectedLines: 0,
logLevel: log.WarnLevel,
expectedErr: "failed to read gz test_files/bad.gz: unexpected EOF",
expectedLines: 0,
logLevel: log.WarnLevel,
},
{
name: "deleted file",
config: `
mode: cat
filename: test_files/test_delete.log`,
@ -195,77 +202,84 @@ filename: test_files/test_delete.log`,
},
}
for _, ts := range tests {
logger, hook := test.NewNullLogger()
logger.SetLevel(ts.logLevel)
subLogger := logger.WithFields(log.Fields{
"type": "file",
})
tomb := tomb.Tomb{}
out := make(chan types.Event)
f := FileSource{}
if ts.setup != nil {
ts.setup()
}
err := f.Configure([]byte(ts.config), subLogger)
cstest.AssertErrorContains(t, err, ts.expectedConfigErr)
if err != nil {
continue
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
logger, hook := test.NewNullLogger()
logger.SetLevel(tc.logLevel)
if ts.afterConfigure != nil {
ts.afterConfigure()
}
actualLines := 0
if ts.expectedLines != 0 {
go func() {
READLOOP:
for {
select {
case <-out:
actualLines++
case <-time.After(1 * time.Second):
break READLOOP
subLogger := logger.WithFields(log.Fields{
"type": "file",
})
tomb := tomb.Tomb{}
out := make(chan types.Event)
f := fileacquisition.FileSource{}
if tc.setup != nil {
tc.setup()
}
err := f.Configure([]byte(tc.config), subLogger)
cstest.RequireErrorContains(t, err, tc.expectedConfigErr)
if tc.expectedConfigErr != "" {
return
}
if tc.afterConfigure != nil {
tc.afterConfigure()
}
actualLines := 0
if tc.expectedLines != 0 {
go func() {
for {
select {
case <-out:
actualLines++
case <-time.After(2 * time.Second):
return
}
}
}
}()
}
err = f.OneShotAcquisition(out, &tomb)
cstest.AssertErrorContains(t, err, ts.expectedErr)
}()
}
if ts.expectedLines != 0 {
assert.Equal(t, actualLines, ts.expectedLines)
}
if ts.expectedOutput != "" {
assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput)
hook.Reset()
}
if ts.teardown != nil {
ts.teardown()
}
err = f.OneShotAcquisition(out, &tomb)
cstest.RequireErrorContains(t, err, tc.expectedErr)
if tc.expectedLines != 0 {
assert.Equal(t, tc.expectedLines, actualLines)
}
if tc.expectedOutput != "" {
assert.Contains(t, hook.LastEntry().Message, tc.expectedOutput)
hook.Reset()
}
if tc.teardown != nil {
tc.teardown()
}
})
}
}
func TestLiveAcquisition(t *testing.T) {
var permDeniedFile string
var permDeniedError string
var testPattern string
if runtime.GOOS != "windows" {
permDeniedFile = "/etc/shadow"
permDeniedError = "unable to read /etc/shadow : open /etc/shadow: permission denied"
testPattern = "test_files/*.log"
} else {
//Technically, this is not a permission denied error, but we just want to test what happens
//if we do not have access to the file
permDeniedFile = "C:\\Windows\\System32\\config\\SAM"
permDeniedError = "unable to read C:\\Windows\\System32\\config\\SAM : open C:\\Windows\\System32\\config\\SAM: The process cannot access the file because it is being used by another process"
testPattern = "test_files\\\\*.log" // the \ must be escaped twice: once for the string, once for the yaml config
permDeniedFile := "/etc/shadow"
permDeniedError := "unable to read /etc/shadow : open /etc/shadow: permission denied"
testPattern := "test_files/*.log"
if runtime.GOOS == "windows" {
// Technically, this is not a permission denied error, but we just want to test what happens
// if we do not have access to the file
permDeniedFile = `C:\Windows\System32\config\SAM`
permDeniedError = `unable to read C:\Windows\System32\config\SAM : open C:\Windows\System32\config\SAM: The process cannot access the file because it is being used by another process`
testPattern = `test_files\\*.log` // the \ must be escaped for the yaml config
}
tests := []struct {
name string
config string
expectedErr string
expectedOutput string
name string
expectedLines int
logLevel log.Level
setup func()
@ -276,7 +290,6 @@ func TestLiveAcquisition(t *testing.T) {
config: fmt.Sprintf(`
mode: tail
filename: %s`, permDeniedFile),
expectedErr: "",
expectedOutput: permDeniedError,
logLevel: log.InfoLevel,
expectedLines: 0,
@ -286,7 +299,6 @@ filename: %s`, permDeniedFile),
config: `
mode: tail
filename: /`,
expectedErr: "",
expectedOutput: "/ is a directory, ignoring it",
logLevel: log.WarnLevel,
expectedLines: 0,
@ -296,7 +308,6 @@ filename: /`,
config: `
mode: tail
filename: /do/not/exist`,
expectedErr: "",
expectedOutput: "No matching files for pattern /do/not/exist",
logLevel: log.WarnLevel,
expectedLines: 0,
@ -308,11 +319,9 @@ mode: tail
filenames:
- %s
force_inotify: true`, testPattern),
expectedErr: "",
expectedOutput: "",
expectedLines: 5,
logLevel: log.DebugLevel,
name: "basicGlob",
expectedLines: 5,
logLevel: log.DebugLevel,
name: "basicGlob",
},
{
config: fmt.Sprintf(`
@ -320,11 +329,9 @@ mode: tail
filenames:
- %s
force_inotify: true`, testPattern),
expectedErr: "",
expectedOutput: "",
expectedLines: 0,
logLevel: log.DebugLevel,
name: "GlobInotify",
expectedLines: 0,
logLevel: log.DebugLevel,
name: "GlobInotify",
afterConfigure: func() {
f, _ := os.Create("test_files/a.log")
f.Close()
@ -338,19 +345,17 @@ mode: tail
filenames:
- %s
force_inotify: true`, testPattern),
expectedErr: "",
expectedOutput: "",
expectedLines: 5,
logLevel: log.DebugLevel,
name: "GlobInotifyChmod",
expectedLines: 5,
logLevel: log.DebugLevel,
name: "GlobInotifyChmod",
afterConfigure: func() {
f, _ := os.Create("test_files/a.log")
f.Close()
time.Sleep(1 * time.Second)
os.Chmod("test_files/a.log", 0000)
os.Chmod("test_files/a.log", 0o000)
},
teardown: func() {
os.Chmod("test_files/a.log", 0644)
os.Chmod("test_files/a.log", 0o644)
os.Remove("test_files/a.log")
},
},
@ -360,13 +365,11 @@ mode: tail
filenames:
- %s
force_inotify: true`, testPattern),
expectedErr: "",
expectedOutput: "",
expectedLines: 5,
logLevel: log.DebugLevel,
name: "InotifyMkDir",
expectedLines: 5,
logLevel: log.DebugLevel,
name: "InotifyMkDir",
afterConfigure: func() {
os.Mkdir("test_files/pouet/", 0700)
os.Mkdir("test_files/pouet/", 0o700)
},
teardown: func() {
os.Remove("test_files/pouet/")
@ -374,101 +377,112 @@ force_inotify: true`, testPattern),
},
}
for _, ts := range tests {
t.Logf("test: %s", ts.name)
logger, hook := test.NewNullLogger()
logger.SetLevel(ts.logLevel)
subLogger := logger.WithFields(log.Fields{
"type": "file",
})
tomb := tomb.Tomb{}
out := make(chan types.Event)
f := FileSource{}
if ts.setup != nil {
ts.setup()
}
err := f.Configure([]byte(ts.config), subLogger)
if err != nil {
t.Fatalf("Unexpected error : %s", err)
}
if ts.afterConfigure != nil {
ts.afterConfigure()
}
actualLines := 0
if ts.expectedLines != 0 {
go func() {
READLOOP:
for {
select {
case <-out:
actualLines++
case <-time.After(2 * time.Second):
break READLOOP
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
logger, hook := test.NewNullLogger()
logger.SetLevel(tc.logLevel)
subLogger := logger.WithFields(log.Fields{
"type": "file",
})
tomb := tomb.Tomb{}
out := make(chan types.Event)
f := fileacquisition.FileSource{}
if tc.setup != nil {
tc.setup()
}
err := f.Configure([]byte(tc.config), subLogger)
require.NoError(t, err)
if tc.afterConfigure != nil {
tc.afterConfigure()
}
actualLines := 0
if tc.expectedLines != 0 {
go func() {
for {
select {
case <-out:
actualLines++
case <-time.After(2 * time.Second):
return
}
}
}()
}
err = f.StreamingAcquisition(out, &tomb)
cstest.RequireErrorContains(t, err, tc.expectedErr)
if tc.expectedLines != 0 {
fd, err := os.Create("test_files/stream.log")
if err != nil {
t.Fatalf("could not create test file : %s", err)
}
for i := 0; i < 5; i++ {
_, err = fmt.Fprintf(fd, "%d\n", i)
if err != nil {
t.Fatalf("could not write test file : %s", err)
os.Remove("test_files/stream.log")
}
}
}()
}
err = f.StreamingAcquisition(out, &tomb)
cstest.AssertErrorContains(t, err, ts.expectedErr)
if ts.expectedLines != 0 {
fd, err := os.Create("test_files/stream.log")
if err != nil {
t.Fatalf("could not create test file : %s", err)
fd.Close()
// we sleep to make sure we detect the new file
time.Sleep(1 * time.Second)
os.Remove("test_files/stream.log")
assert.Equal(t, tc.expectedLines, actualLines)
}
for i := 0; i < 5; i++ {
_, err = fd.WriteString(fmt.Sprintf("%d\n", i))
if err != nil {
t.Fatalf("could not write test file : %s", err)
os.Remove("test_files/stream.log")
if tc.expectedOutput != "" {
if hook.LastEntry() == nil {
t.Fatalf("expected output %s, but got nothing", tc.expectedOutput)
}
assert.Contains(t, hook.LastEntry().Message, tc.expectedOutput)
hook.Reset()
}
fd.Close()
//we sleep to make sure we detect the new file
time.Sleep(1 * time.Second)
os.Remove("test_files/stream.log")
assert.Equal(t, ts.expectedLines, actualLines)
}
if ts.expectedOutput != "" {
if hook.LastEntry() == nil {
t.Fatalf("expected output %s, but got nothing", ts.expectedOutput)
if tc.teardown != nil {
tc.teardown()
}
assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput)
hook.Reset()
}
if ts.teardown != nil {
ts.teardown()
}
tomb.Kill(nil)
tomb.Kill(nil)
})
}
}
func TestExclusion(t *testing.T) {
config := `filenames: ["test_files/*.log*"]
exclude_regexps: ["\\.gz$"]`
logger, hook := test.NewNullLogger()
//logger.SetLevel(ts.logLevel)
// logger.SetLevel(ts.logLevel)
subLogger := logger.WithFields(log.Fields{
"type": "file",
})
f := FileSource{}
err := f.Configure([]byte(config), subLogger)
if err != nil {
f := fileacquisition.FileSource{}
if err := f.Configure([]byte(config), subLogger); err != nil {
subLogger.Fatalf("unexpected error: %s", err)
}
var expectedLogOutput string
expectedLogOutput := "Skipping file test_files/test.log.gz as it matches exclude pattern"
if runtime.GOOS == "windows" {
expectedLogOutput = "Skipping file test_files\\test.log.gz as it matches exclude pattern \\.gz"
} else {
expectedLogOutput = "Skipping file test_files/test.log.gz as it matches exclude pattern"
expectedLogOutput = `Skipping file test_files\test.log.gz as it matches exclude pattern \.gz`
}
if hook.LastEntry() == nil {
t.Fatalf("expected output %s, but got nothing", expectedLogOutput)
}
assert.Contains(t, hook.LastEntry().Message, expectedLogOutput)
hook.Reset()
}

View file

@ -400,7 +400,7 @@ func fillAlertsWithDecisions(alerts []*models.Alert, decisions []*models.Decisio
return alerts
}
//we receive only one list of decisions, that we need to break-up :
// we receive only one list of decisions, that we need to break-up :
// one alert for "community blocklist"
// one alert per list we're subscribed to
func (a *apic) PullTop() error {
@ -432,7 +432,7 @@ func (a *apic) PullTop() error {
return nil
}
//we receive only one list of decisions, that we need to break-up :
// we receive only one list of decisions, that we need to break-up :
// one alert for "community blocklist"
// one alert per list we're subscribed to
alertsFromCapi := createAlertsForDecisions(data.New)
@ -541,37 +541,32 @@ func (a *apic) GetMetrics() (*models.Metrics, error) {
return metric, nil
}
func (a *apic) SendMetrics() error {
func (a *apic) SendMetrics(stop chan (bool)) {
defer types.CatchPanic("lapi/metricsToAPIC")
metrics, err := a.GetMetrics()
if err != nil {
log.Errorf("unable to get metrics (%s), will retry", err)
}
_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
if err != nil {
log.Errorf("unable to send metrics (%s), will retry", err)
}
log.Infof("capi metrics: metrics sent successfully")
log.Infof("Start send metrics to CrowdSec Central API (interval: %s)", MetricsInterval)
log.Infof("Start send metrics to CrowdSec Central API (interval: %s)", a.metricsInterval)
ticker := time.NewTicker(a.metricsInterval)
for {
metrics, err := a.GetMetrics()
if err != nil {
log.Errorf("unable to get metrics (%s), will retry", err)
}
_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
if err != nil {
log.Errorf("capi metrics: failed: %s", err)
} else {
log.Infof("capi metrics: metrics sent successfully")
}
select {
case <-stop:
return
case <-ticker.C:
metrics, err := a.GetMetrics()
if err != nil {
log.Errorf("unable to get metrics (%s), will retry", err)
}
_, _, err = a.apiClient.Metrics.Add(context.Background(), metrics)
if err != nil {
log.Errorf("capi metrics: failed: %s", err)
} else {
log.Infof("capi metrics: metrics sent successfully")
}
continue
case <-a.metricsTomb.Dying(): // if one apic routine is dying, do we kill the others?
a.pullTomb.Kill(nil)
a.pushTomb.Kill(nil)
return nil
return
}
}
}

View file

@ -8,13 +8,13 @@ import (
"net/url"
"os"
"reflect"
"sort"
"sync"
"testing"
"time"
"github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/cstest"
"github.com/crowdsecurity/crowdsec/pkg/cwversion"
"github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
@ -24,23 +24,20 @@ import (
"github.com/jarcoal/httpmock"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/tomb.v2"
)
func getDBClient(t *testing.T) *database.Client {
t.Helper()
dbPath, err := os.CreateTemp("", "*sqlite")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
dbClient, err := database.NewClient(&csconfig.DatabaseCfg{
Type: "sqlite",
DbName: "crowdsec",
DbPath: dbPath.Name(),
})
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
return dbClient
}
@ -98,11 +95,9 @@ func assertTotalAlertCount(t *testing.T, dbClient *database.Client, count int) {
func TestAPICCAPIPullIsOld(t *testing.T) {
api := getAPIC(t)
isOld, err := api.CAPIPullIsOld()
if err != nil {
t.Fatal(err)
}
isOld, err := api.CAPIPullIsOld()
require.NoError(t, err)
assert.True(t, isOld)
decision := api.dbClient.Ent.Decision.Create().
@ -123,16 +118,13 @@ func TestAPICCAPIPullIsOld(t *testing.T) {
SaveX(context.Background())
isOld, err = api.CAPIPullIsOld()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assert.False(t, isOld)
}
func TestAPICFetchScenariosListFromDB(t *testing.T) {
api := getAPIC(t)
testCases := []struct {
tests := []struct {
name string
machineIDsWithScenarios map[string]string
expectedScenarios []string
@ -154,8 +146,10 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
},
}
for _, tc := range testCases {
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
api := getAPIC(t)
for machineID, scenarios := range tc.machineIDsWithScenarios {
api.dbClient.Ent.Machine.Create().
SetMachineId(machineID).
@ -164,17 +158,14 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
SetScenarios(scenarios).
ExecX(context.Background())
}
scenarios, err := api.FetchScenariosListFromDB()
for machineID := range tc.machineIDsWithScenarios {
api.dbClient.Ent.Machine.Delete().Where(machine.MachineIdEQ(machineID)).ExecX(context.Background())
}
if err != nil {
t.Fatal(err)
} else {
sort.Strings(scenarios)
sort.Strings(tc.expectedScenarios)
assert.Equal(t, scenarios, tc.expectedScenarios)
}
require.NoError(t, err)
assert.ElementsMatch(t, tc.expectedScenarios, scenarios)
})
}
@ -196,11 +187,10 @@ func TestNewAPIC(t *testing.T) {
consoleConfig *csconfig.ConsoleConfig
}
tests := []struct {
name string
args args
wantErr bool
errorContains string
action func()
name string
args args
expectedErr string
action func()
}{
{
name: "simple",
@ -217,20 +207,16 @@ func TestNewAPIC(t *testing.T) {
dbClient: getDBClient(t),
consoleConfig: LoadTestConfig().API.Server.ConsoleConfig,
},
wantErr: true,
errorContains: "first path segment in URL cannot contain colon",
expectedErr: "first path segment in URL cannot contain colon",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
setConfig()
tt.action()
_, err := NewAPIC(testConfig, tt.args.dbClient, tt.args.consoleConfig)
if tt.wantErr {
assert.ErrorContains(t, err, tt.errorContains)
} else {
assert.NoError(t, err)
}
tc.action()
_, err := NewAPIC(testConfig, tc.args.dbClient, tc.args.consoleConfig)
cstest.RequireErrorContains(t, err, tc.expectedErr)
})
}
}
@ -268,17 +254,16 @@ func TestAPICHandleDeletedDecisions(t *testing.T) {
}}, deleteCounters)
assert.NoError(t, err)
assert.Equal(t, nbDeleted, 2)
assert.Equal(t, deleteCounters[SCOPE_CAPI]["all"], 2)
assert.Equal(t, 2, nbDeleted)
assert.Equal(t, 2, deleteCounters[SCOPE_CAPI]["all"])
}
func TestAPICGetMetrics(t *testing.T) {
api := getAPIC(t)
cleanUp := func() {
cleanUp := func(api *apic) {
api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background())
api.dbClient.Ent.Machine.Delete().ExecX(context.Background())
}
testCases := []struct {
tests := []struct {
name string
machineIDs []string
bouncers []string
@ -322,11 +307,13 @@ func TestAPICGetMetrics(t *testing.T) {
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
cleanUp()
for i, machineID := range testCase.machineIDs {
api.dbClient.Ent.Machine.Create().
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
apiClient := getAPIC(t)
cleanUp(apiClient)
for i, machineID := range tc.machineIDs {
apiClient.dbClient.Ent.Machine.Create().
SetMachineId(machineID).
SetPassword(testPassword.String()).
SetIpAddress(fmt.Sprintf("1.2.3.%d", i)).
@ -336,8 +323,8 @@ func TestAPICGetMetrics(t *testing.T) {
ExecX(context.Background())
}
for i, bouncerName := range testCase.bouncers {
api.dbClient.Ent.Bouncer.Create().
for i, bouncerName := range tc.bouncers {
apiClient.dbClient.Ent.Bouncer.Create().
SetIPAddress(fmt.Sprintf("1.2.3.%d", i)).
SetName(bouncerName).
SetAPIKey("foobar").
@ -346,19 +333,17 @@ func TestAPICGetMetrics(t *testing.T) {
ExecX(context.Background())
}
if foundMetrics, err := api.GetMetrics(); err != nil {
t.Fatal(err)
} else {
assert.Equal(t, foundMetrics.Bouncers, testCase.expectedMetric.Bouncers)
assert.Equal(t, foundMetrics.Machines, testCase.expectedMetric.Machines)
foundMetrics, err := apiClient.GetMetrics()
require.NoError(t, err)
assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers)
assert.Equal(t, tc.expectedMetric.Machines, foundMetrics.Machines)
}
})
}
}
func TestCreateAlertsForDecision(t *testing.T) {
httpBfDecisionList := &models.Decision{
Origin: &SCOPE_LISTS,
Scenario: types.StrPtr("crowdsecurity/http-bf"),
@ -427,10 +412,11 @@ func TestCreateAlertsForDecision(t *testing.T) {
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := createAlertsForDecisions(tt.args.decisions); !reflect.DeepEqual(got, tt.want) {
t.Errorf("createAlertsForDecisions() = %v, want %v", got, tt.want)
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
if got := createAlertsForDecisions(tc.args.decisions); !reflect.DeepEqual(got, tc.want) {
t.Errorf("createAlertsForDecisions() = %v, want %v", got, tc.want)
}
})
}
@ -503,11 +489,12 @@ func TestFillAlertsWithDecisions(t *testing.T) {
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
add_counters, _ := makeAddAndDeleteCounters()
if got := fillAlertsWithDecisions(tt.args.alerts, tt.args.decisions, add_counters); !reflect.DeepEqual(got, tt.want) {
t.Errorf("fillAlertsWithDecisions() = %v, want %v", got, tt.want)
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
addCounters, _ := makeAddAndDeleteCounters()
if got := fillAlertsWithDecisions(tc.args.alerts, tc.args.decisions, addCounters); !reflect.DeepEqual(got, tc.want) {
t.Errorf("fillAlertsWithDecisions() = %v, want %v", got, tc.want)
}
})
}
@ -586,24 +573,19 @@ func TestAPICPullTop(t *testing.T) {
),
))
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
apic, err := apiclient.NewDefaultClient(
url,
"/api",
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
nil,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
api.apiClient = apic
err = api.PullTop()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assertTotalDecisionCount(t, api.dbClient, 5)
assertTotalValidDecisionCount(t, api.dbClient, 4)
@ -619,24 +601,23 @@ func TestAPICPullTop(t *testing.T) {
for _, alert := range alerts {
alertScenario[alert.SourceScope]++
}
assert.Equal(t, len(alertScenario), 3)
assert.Equal(t, alertScenario[SCOPE_CAPI_ALIAS], 1)
assert.Equal(t, alertScenario["lists:crowdsecurity/ssh-bf"], 1)
assert.Equal(t, alertScenario["lists:crowdsecurity/http-bf"], 1)
assert.Equal(t, 3, len(alertScenario))
assert.Equal(t, 1, alertScenario[SCOPE_CAPI_ALIAS])
assert.Equal(t, 1, alertScenario["lists:crowdsecurity/ssh-bf"])
assert.Equal(t, 1, alertScenario["lists:crowdsecurity/http-bf"])
for _, decisions := range validDecisions {
decisionScenarioFreq[decisions.Scenario]++
}
assert.Equal(t, decisionScenarioFreq["crowdsecurity/http-bf"], 1)
assert.Equal(t, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1)
assert.Equal(t, decisionScenarioFreq["crowdsecurity/test1"], 1)
assert.Equal(t, decisionScenarioFreq["crowdsecurity/test2"], 1)
assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/http-bf"], 1)
assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/ssh-bf"], 1)
assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test1"], 1)
assert.Equal(t, 1, decisionScenarioFreq["crowdsecurity/test2"], 1)
}
func TestAPICPush(t *testing.T) {
testCases := []struct {
tests := []struct {
name string
alerts []*models.Alert
expectedCalls int
@ -683,14 +664,14 @@ func TestAPICPush(t *testing.T) {
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
api := getAPIC(t)
api.pushInterval = time.Millisecond
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
httpmock.Activate()
defer httpmock.DeactivateAndReset()
apic, err := apiclient.NewDefaultClient(
@ -699,31 +680,28 @@ func TestAPICPush(t *testing.T) {
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
nil,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
api.apiClient = apic
httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/signals", httpmock.NewBytesResponder(200, []byte{}))
go func() {
api.alertToPush <- testCase.alerts
api.alertToPush <- tc.alerts
time.Sleep(time.Second)
api.Shutdown()
}()
if err := api.Push(); err != nil {
t.Fatal(err)
}
assert.Equal(t, httpmock.GetTotalCallCount(), testCase.expectedCalls)
err = api.Push()
require.NoError(t, err)
assert.Equal(t, tc.expectedCalls, httpmock.GetTotalCallCount())
})
}
}
func TestAPICSendMetrics(t *testing.T) {
api := getAPIC(t)
testCases := []struct {
tests := []struct {
name string
duration time.Duration
expectedCalls int
setUp func()
setUp func(*apic)
metricsInterval time.Duration
}{
{
@ -731,14 +709,15 @@ func TestAPICSendMetrics(t *testing.T) {
duration: time.Millisecond * 30,
metricsInterval: time.Millisecond * 5,
expectedCalls: 5,
setUp: func() {},
setUp: func(api *apic) {},
},
{
name: "with some metrics",
duration: time.Millisecond * 30,
metricsInterval: time.Millisecond * 5,
expectedCalls: 5,
setUp: func() {
setUp: func(api *apic) {
api.dbClient.Ent.Machine.Delete().ExecX(context.Background())
api.dbClient.Ent.Machine.Create().
SetMachineId("1234").
SetPassword(testPassword.String()).
@ -748,6 +727,7 @@ func TestAPICSendMetrics(t *testing.T) {
SetUpdatedAt(time.Time{}).
ExecX(context.Background())
api.dbClient.Ent.Bouncer.Delete().ExecX(context.Background())
api.dbClient.Ent.Bouncer.Create().
SetIPAddress("1.2.3.6").
SetName("someBouncer").
@ -758,44 +738,49 @@ func TestAPICSendMetrics(t *testing.T) {
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
api = getAPIC(t)
api.pushInterval = time.Millisecond
httpmock.RegisterResponder("POST", "http://api.crowdsec.net/api/metrics/", httpmock.NewBytesResponder(200, []byte{}))
httpmock.Activate()
defer httpmock.Deactivate()
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
if err != nil {
t.Fatal(err)
}
httpmock.Activate()
defer httpmock.DeactivateAndReset()
apic, err := apiclient.NewDefaultClient(
require.NoError(t, err)
apiClient, err := apiclient.NewDefaultClient(
url,
"/api",
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
nil,
)
if err != nil {
t.Fatal(err)
}
api.apiClient = apic
api.metricsInterval = testCase.metricsInterval
httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, []byte{}))
testCase.setUp()
require.NoError(t, err)
go func() {
if err := api.SendMetrics(); err != nil {
panic(err)
}
}()
time.Sleep(testCase.duration)
assert.LessOrEqual(t, absDiff(testCase.expectedCalls, httpmock.GetTotalCallCount()), 2)
api := getAPIC(t)
api.pushInterval = time.Millisecond
api.apiClient = apiClient
api.metricsInterval = tc.metricsInterval
tc.setUp(api)
stop := make(chan bool)
httpmock.ZeroCallCounters()
go api.SendMetrics(stop)
time.Sleep(tc.duration)
stop <- true
info := httpmock.GetCallCountInfo()
noResponderCalls := info["NO_RESPONDER"]
responderCalls := info["POST http://api.crowdsec.net/api/metrics/"]
assert.LessOrEqual(t, absDiff(tc.expectedCalls, responderCalls), 2)
assert.Zero(t, noResponderCalls)
})
}
}
func TestAPICPull(t *testing.T) {
api := getAPIC(t)
testCases := []struct {
tests := []struct {
name string
setUp func()
expectedDecisionCount int
@ -820,14 +805,13 @@ func TestAPICPull(t *testing.T) {
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
api = getAPIC(t)
api.pullInterval = time.Millisecond
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
httpmock.Activate()
defer httpmock.DeactivateAndReset()
apic, err := apiclient.NewDefaultClient(
@ -836,9 +820,7 @@ func TestAPICPull(t *testing.T) {
fmt.Sprintf("crowdsec/%s", cwversion.VersionStr()),
nil,
)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
api.apiClient = apic
httpmock.RegisterNoResponder(httpmock.NewBytesResponder(200, jsonMarshalX(
models.DecisionsStreamResponse{
@ -854,7 +836,7 @@ func TestAPICPull(t *testing.T) {
},
},
)))
testCase.setUp()
tc.setUp()
var buf bytes.Buffer
go func() {
logrus.SetOutput(&buf)
@ -865,15 +847,14 @@ func TestAPICPull(t *testing.T) {
//Slightly long because the CI runner for windows are slow, and this can lead to random failure
time.Sleep(time.Millisecond * 500)
logrus.SetOutput(os.Stderr)
assert.Contains(t, buf.String(), testCase.logContains)
assertTotalDecisionCount(t, api.dbClient, testCase.expectedDecisionCount)
assert.Contains(t, buf.String(), tc.logContains)
assertTotalDecisionCount(t, api.dbClient, tc.expectedDecisionCount)
})
}
}
func TestShouldShareAlert(t *testing.T) {
testCases := []struct {
tests := []struct {
name string
consoleConfig *csconfig.ConsoleConfig
alert *models.Alert
@ -948,10 +929,11 @@ func TestShouldShareAlert(t *testing.T) {
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ret := shouldShareAlert(testCase.alert, testCase.consoleConfig)
assert.Equal(t, ret, testCase.expectedRet)
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
ret := shouldShareAlert(tc.alert, tc.consoleConfig)
assert.Equal(t, tc.expectedRet, ret)
})
}
}

View file

@ -311,10 +311,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
return nil
})
s.apic.metricsTomb.Go(func() error {
if err := s.apic.SendMetrics(); err != nil {
log.Errorf("capi metrics: %s", err)
return err
}
s.apic.SendMetrics(make(chan bool))
return nil
})
}

View file

@ -7,6 +7,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Copy(sourceFile string, destinationFile string) error {
@ -110,6 +111,8 @@ func CopyDir(src string, dest string) error {
}
func AssertErrorContains(t *testing.T, err error, expectedErr string) {
t.Helper()
if expectedErr != "" {
assert.ErrorContains(t, err, expectedErr)
return
@ -117,3 +120,14 @@ func AssertErrorContains(t *testing.T, err error, expectedErr string) {
assert.NoError(t, err)
}
func RequireErrorContains(t *testing.T, err error, expectedErr string) {
t.Helper()
if expectedErr != "" {
require.ErrorContains(t, err, expectedErr)
return
}
require.NoError(t, err)
}