diff --git a/.github/workflows/bats-hub.yml b/.github/workflows/bats-hub.yml index 5520b0dac..faff0a799 100644 --- a/.github/workflows/bats-hub.yml +++ b/.github/workflows/bats-hub.yml @@ -36,7 +36,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v4 with: - go-version: "1.21.5" + go-version: "1.21.6" - name: "Install bats dependencies" env: diff --git a/.github/workflows/bats-mysql.yml b/.github/workflows/bats-mysql.yml index e0ea4001f..bceaf6ce7 100644 --- a/.github/workflows/bats-mysql.yml +++ b/.github/workflows/bats-mysql.yml @@ -14,7 +14,7 @@ jobs: build: strategy: matrix: - go-version: ["1.21.5"] + go-version: ["1.21.6"] name: "Build + tests" runs-on: ubuntu-latest diff --git a/.github/workflows/bats-postgres.yml b/.github/workflows/bats-postgres.yml index af870ed83..b058f79b2 100644 --- a/.github/workflows/bats-postgres.yml +++ b/.github/workflows/bats-postgres.yml @@ -10,7 +10,7 @@ jobs: build: strategy: matrix: - go-version: ["1.21.5"] + go-version: ["1.21.6"] name: "Build + tests" runs-on: ubuntu-latest diff --git a/.github/workflows/bats-sqlite-coverage.yml b/.github/workflows/bats-sqlite-coverage.yml index 2f2b15e11..ee1908b92 100644 --- a/.github/workflows/bats-sqlite-coverage.yml +++ b/.github/workflows/bats-sqlite-coverage.yml @@ -11,7 +11,7 @@ jobs: build: strategy: matrix: - go-version: ["1.21.5"] + go-version: ["1.21.6"] name: "Build + tests" runs-on: ubuntu-latest diff --git a/.github/workflows/ci-windows-build-msi.yml b/.github/workflows/ci-windows-build-msi.yml index deff03063..bfb2cdaca 100644 --- a/.github/workflows/ci-windows-build-msi.yml +++ b/.github/workflows/ci-windows-build-msi.yml @@ -23,7 +23,7 @@ jobs: build: strategy: matrix: - go-version: ["1.21.5"] + go-version: ["1.21.6"] name: Build runs-on: windows-2019 diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index ce9482274..f23355b49 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -74,7 +74,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v4 with: - go-version: "1.21.0" + go-version: "1.21.6" cache-dependency-path: "**/go.sum" - run: | diff --git a/.github/workflows/go-tests-windows.yml b/.github/workflows/go-tests-windows.yml index 68284c10b..3f36327f3 100644 --- a/.github/workflows/go-tests-windows.yml +++ b/.github/workflows/go-tests-windows.yml @@ -22,7 +22,7 @@ jobs: build: strategy: matrix: - go-version: ["1.21.5"] + go-version: ["1.21.6"] name: "Build + tests" runs-on: windows-2022 diff --git a/.github/workflows/go-tests.yml b/.github/workflows/go-tests.yml index 05104caf8..f6d2f9c98 100644 --- a/.github/workflows/go-tests.yml +++ b/.github/workflows/go-tests.yml @@ -126,7 +126,7 @@ jobs: - name: "Set up Go" uses: actions/setup-go@v4 with: - go-version: "1.21.5" + go-version: "1.21.6" - name: Create localstack streams run: | diff --git a/.github/workflows/publish_docker-image_on_master-debian.yml b/.github/workflows/publish_docker-image_on_master-debian.yml index 88076157c..17332adf0 100644 --- a/.github/workflows/publish_docker-image_on_master-debian.yml +++ b/.github/workflows/publish_docker-image_on_master-debian.yml @@ -19,6 +19,7 @@ jobs: push_to_registry: name: Push Debian Docker image to Docker Hub runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'crowdsecurity' }} steps: - name: Check out the repo diff --git a/.github/workflows/publish_docker-image_on_master.yml b/.github/workflows/publish_docker-image_on_master.yml index 6cab486b0..345290200 100644 --- a/.github/workflows/publish_docker-image_on_master.yml +++ b/.github/workflows/publish_docker-image_on_master.yml @@ -19,6 +19,7 @@ jobs: push_to_registry: name: Push Docker image to Docker Hub runs-on: ubuntu-latest + if: ${{ github.repository_owner == 'crowdsecurity' }} steps: - name: Check out the repo diff --git a/.github/workflows/release_publish-package.yml b/.github/workflows/release_publish-package.yml index b2784d391..855915824 100644 --- a/.github/workflows/release_publish-package.yml +++ b/.github/workflows/release_publish-package.yml @@ -14,7 +14,7 @@ jobs: build: strategy: matrix: - go-version: ["1.21.5"] + go-version: ["1.21.6"] name: Build and upload binary package runs-on: ubuntu-latest diff --git a/.golangci.yml b/.golangci.yml index 13d8b4534..5c0bab58c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -9,6 +9,13 @@ run: - pkg/yamlpatch/merge_test.go linters-settings: + gci: + sections: + - standard + - default + - prefix(github.com/crowdsecurity) + - prefix(github.com/crowdsecurity/crowdsec) + gocyclo: min-complexity: 30 @@ -235,42 +242,6 @@ issues: # Will fix, trivial - just beware of merge conflicts - - linters: - - testifylint - text: "expected-actual: need to reverse actual and expected values" - - - linters: - - testifylint - text: "bool-compare: use assert.False" - - - linters: - - testifylint - text: "len: use assert.Len" - - - linters: - - testifylint - text: "bool-compare: use assert.True" - - - linters: - - testifylint - text: "bool-compare: use require.True" - - - linters: - - testifylint - text: "require-error: for error assertions use require" - - - linters: - - testifylint - text: "error-nil: use assert.NoError" - - - linters: - - testifylint - text: "error-nil: use assert.Error" - - - linters: - - testifylint - text: "empty: use assert.Empty" - - linters: - perfsprint text: "fmt.Sprintf can be replaced .*" @@ -279,10 +250,6 @@ issues: # Will fix, easy but some neurons required # - - linters: - - testifylint - text: "float-compare: use assert.InEpsilon .*or InDelta.*" - - linters: - errorlint text: "non-wrapping format verb for fmt.Errorf. Use `%w` to format errors" diff --git a/Dockerfile b/Dockerfile index 3b95ca3c4..7470beb57 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,6 @@ # vim: set ft=dockerfile: -ARG GOVERSION=1.21.5 +ARG GOVERSION=1.21.6 +ARG BUILD_VERSION FROM golang:${GOVERSION}-alpine3.18 AS build @@ -7,6 +8,7 @@ WORKDIR /go/src/crowdsec # We like to choose the release of re2 to use, and Alpine does not ship a static version anyway. ENV RE2_VERSION=2023-03-01 +ENV BUILD_VERSION=${BUILD_VERSION} # wizard.sh requires GNU coreutils RUN apk add --no-cache git g++ gcc libc-dev make bash gettext binutils-gold coreutils pkgconfig && \ diff --git a/Dockerfile.debian b/Dockerfile.debian index 2f116aa73..bc5b0aa2d 100644 --- a/Dockerfile.debian +++ b/Dockerfile.debian @@ -1,5 +1,6 @@ # vim: set ft=dockerfile: -ARG GOVERSION=1.21.5 +ARG GOVERSION=1.21.6 +ARG BUILD_VERSION FROM golang:${GOVERSION}-bookworm AS build @@ -10,6 +11,7 @@ ENV DEBCONF_NOWARNINGS="yes" # We like to choose the release of re2 to use, the debian version is usually older. ENV RE2_VERSION=2023-03-01 +ENV BUILD_VERSION=${BUILD_VERSION} # wizard.sh requires GNU coreutils RUN apt-get update && \ diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 1b7cd4960..82caba42b 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -25,9 +25,9 @@ stages: custom: 'tool' arguments: 'install --global SignClient --version 1.3.155' - task: GoTool@0 - displayName: "Install Go 1.20" + displayName: "Install Go" inputs: - version: '1.21.5' + version: '1.21.6' - pwsh: | choco install -y make diff --git a/cmd/crowdsec-cli/capi.go b/cmd/crowdsec-cli/capi.go index ea1d127cc..358d91ee2 100644 --- a/cmd/crowdsec-cli/capi.go +++ b/cmd/crowdsec-cli/capi.go @@ -38,7 +38,7 @@ func (cli cliCapi) NewCommand() *cobra.Command { Short: "Manage interaction with Central API (CAPI)", Args: cobra.MinimumNArgs(1), DisableAutoGenTag: true, - PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + PersistentPreRunE: func(_ *cobra.Command, _ []string) error { if err := require.LAPI(csConfig); err != nil { return err } @@ -58,15 +58,17 @@ func (cli cliCapi) NewCommand() *cobra.Command { } func (cli cliCapi) NewRegisterCmd() *cobra.Command { - var capiUserPrefix string - var outputFile string + var ( + capiUserPrefix string + outputFile string + ) var cmd = &cobra.Command{ Use: "register", Short: "Register to Central API (CAPI)", Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { + RunE: func(_ *cobra.Command, _ []string) error { var err error capiUser, err := generateID(capiUserPrefix) if err != nil { @@ -115,7 +117,7 @@ func (cli cliCapi) NewRegisterCmd() *cobra.Command { } log.Printf("Central API credentials written to '%s'", dumpFile) } else { - fmt.Printf("%s\n", string(apiConfigDump)) + fmt.Println(string(apiConfigDump)) } log.Warning(ReloadMessage()) @@ -126,6 +128,7 @@ func (cli cliCapi) NewRegisterCmd() *cobra.Command { cmd.Flags().StringVarP(&outputFile, "file", "f", "", "output file destination") cmd.Flags().StringVar(&capiUserPrefix, "schmilblick", "", "set a schmilblick (use in tests only)") + if err := cmd.Flags().MarkHidden("schmilblick"); err != nil { log.Fatalf("failed to hide flag: %s", err) } @@ -134,18 +137,14 @@ func (cli cliCapi) NewRegisterCmd() *cobra.Command { } func (cli cliCapi) NewStatusCmd() *cobra.Command { - var cmd = &cobra.Command{ + cmd := &cobra.Command{ Use: "status", Short: "Check status with the Central API (CAPI)", Args: cobra.MinimumNArgs(0), DisableAutoGenTag: true, - RunE: func(cmd *cobra.Command, args []string) error { - if csConfig.API.Server.OnlineClient == nil { - return fmt.Errorf("please provide credentials for the Central API (CAPI) in '%s'", csConfig.API.Server.OnlineClient.CredentialsFilePath) - } - - if csConfig.API.Server.OnlineClient.Credentials == nil { - return fmt.Errorf("no credentials for Central API (CAPI) in '%s'", csConfig.API.Server.OnlineClient.CredentialsFilePath) + RunE: func(_ *cobra.Command, _ []string) error { + if err := require.CAPIRegistered(csConfig); err != nil { + return err } password := strfmt.Password(csConfig.API.Server.OnlineClient.Credentials.Password) diff --git a/cmd/crowdsec-cli/explain.go b/cmd/crowdsec-cli/explain.go index 3e7f48fa0..e6dd3598f 100644 --- a/cmd/crowdsec-cli/explain.go +++ b/cmd/crowdsec-cli/explain.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/hubtest" ) @@ -35,7 +36,7 @@ func GetLineCountForFile(filepath string) (int, error) { return lc, nil } -type cliExplain struct {} +type cliExplain struct{} func NewCLIExplain() *cliExplain { return &cliExplain{} @@ -109,6 +110,7 @@ tail -n 5 myfile.log | cscli explain --type nginx -f - flags.Bool("failures", false, "Only show failed lines") flags.Bool("only-successful-parsers", false, "Only show successful parsers") flags.String("crowdsec", "crowdsec", "Path to crowdsec") + flags.Bool("no-clean", false, "Don't clean runtime environment after tests") return cmd } @@ -136,13 +138,18 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { return err } - opts := hubtest.DumpOpts{} + opts := dumps.DumpOpts{} opts.Details, err = flags.GetBool("verbose") if err != nil { return err } + no_clean, err := flags.GetBool("no-clean") + if err != nil { + return err + } + opts.SkipOk, err = flags.GetBool("failures") if err != nil { return err @@ -172,6 +179,9 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { return fmt.Errorf("couldn't create a temporary directory to store cscli explain result: %s", err) } defer func() { + if no_clean { + return + } if _, err := os.Stat(dir); !os.IsNotExist(err) { if err := os.RemoveAll(dir); err != nil { log.Errorf("unable to delete temporary directory '%s': %s", dir, err) @@ -254,17 +264,17 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { parserDumpFile := filepath.Join(dir, hubtest.ParserResultFileName) bucketStateDumpFile := filepath.Join(dir, hubtest.BucketPourResultFileName) - parserDump, err := hubtest.LoadParserDump(parserDumpFile) + parserDump, err := dumps.LoadParserDump(parserDumpFile) if err != nil { return fmt.Errorf("unable to load parser dump result: %s", err) } - bucketStateDump, err := hubtest.LoadBucketPourDump(bucketStateDumpFile) + bucketStateDump, err := dumps.LoadBucketPourDump(bucketStateDumpFile) if err != nil { return fmt.Errorf("unable to load bucket dump result: %s", err) } - hubtest.DumpTree(*parserDump, *bucketStateDump, opts) + dumps.DumpTree(*parserDump, *bucketStateDump, opts) return nil } diff --git a/cmd/crowdsec-cli/hub.go b/cmd/crowdsec-cli/hub.go index 331cb4cea..3a2913f05 100644 --- a/cmd/crowdsec-cli/hub.go +++ b/cmd/crowdsec-cli/hub.go @@ -155,6 +155,7 @@ func (cli cliHub) upgrade(cmd *cobra.Command, args []string) error { if err != nil { return err } + if didUpdate { updated++ } @@ -191,18 +192,21 @@ func (cli cliHub) types(cmd *cobra.Command, args []string) error { if err != nil { return err } + fmt.Print(string(s)) case "json": jsonStr, err := json.Marshal(cwhub.ItemTypes) if err != nil { return err } + fmt.Println(string(jsonStr)) case "raw": for _, itemType := range cwhub.ItemTypes { fmt.Println(itemType) } } + return nil } diff --git a/cmd/crowdsec-cli/hubappsec.go b/cmd/crowdsec-cli/hubappsec.go index 8371f49b3..ff41ad5f9 100644 --- a/cmd/crowdsec-cli/hubappsec.go +++ b/cmd/crowdsec-cli/hubappsec.go @@ -48,10 +48,11 @@ cscli appsec-configs list crowdsecurity/vpatch`, func NewCLIAppsecRule() *cliItem { inspectDetail := func(item *cwhub.Item) error { - //Only show the converted rules in human mode + // Only show the converted rules in human mode if csConfig.Cscli.Output != "human" { return nil } + appsecRule := appsec.AppsecCollectionConfig{} yamlContent, err := os.ReadFile(item.State.LocalPath) @@ -71,8 +72,16 @@ func NewCLIAppsecRule() *cliItem { if err != nil { return fmt.Errorf("unable to convert rule %s : %s", rule.Name, err) } + fmt.Println(convertedRule) } + + switch ruleType { //nolint:gocritic + case appsec_rule.ModsecurityRuleType: + for _, rule := range appsecRule.SecLangRules { + fmt.Println(rule) + } + } } return nil diff --git a/cmd/crowdsec-cli/hubtest.go b/cmd/crowdsec-cli/hubtest.go index daf22fb5c..1860540e7 100644 --- a/cmd/crowdsec-cli/hubtest.go +++ b/cmd/crowdsec-cli/hubtest.go @@ -16,6 +16,7 @@ import ( "github.com/spf13/cobra" "gopkg.in/yaml.v2" + "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/hubtest" ) @@ -679,8 +680,8 @@ func (cli cliHubTest) NewExplainCmd() *cobra.Command { return fmt.Errorf("unable to load scenario result after run: %s", err) } } - opts := hubtest.DumpOpts{} - hubtest.DumpTree(*test.ParserAssert.TestData, *test.ScenarioAssert.PourData, opts) + opts := dumps.DumpOpts{} + dumps.DumpTree(*test.ParserAssert.TestData, *test.ScenarioAssert.PourData, opts) } return nil diff --git a/cmd/crowdsec-cli/item_suggest.go b/cmd/crowdsec-cli/item_suggest.go index ac9d9da50..d3beee721 100644 --- a/cmd/crowdsec-cli/item_suggest.go +++ b/cmd/crowdsec-cli/item_suggest.go @@ -2,11 +2,11 @@ package main import ( "fmt" + "slices" "strings" "github.com/agext/levenshtein" "github.com/spf13/cobra" - "slices" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/cwhub" diff --git a/cmd/crowdsec-cli/items.go b/cmd/crowdsec-cli/items.go index 560f9dc2a..a1d079747 100644 --- a/cmd/crowdsec-cli/items.go +++ b/cmd/crowdsec-cli/items.go @@ -7,10 +7,10 @@ import ( "io" "os" "path/filepath" + "slices" "strings" "gopkg.in/yaml.v3" - "slices" "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) diff --git a/cmd/crowdsec-cli/main.go b/cmd/crowdsec-cli/main.go index 380e984e8..72b534f9b 100644 --- a/cmd/crowdsec-cli/main.go +++ b/cmd/crowdsec-cli/main.go @@ -2,12 +2,13 @@ package main import ( "os" + "slices" + "time" "github.com/fatih/color" cc "github.com/ivanpirog/coloredcobra" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" - "slices" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" @@ -107,7 +108,7 @@ var NoNeedConfig = []string{ func main() { // set the formatter asap and worry about level later - logFormatter := &log.TextFormatter{TimestampFormat: "2006-01-02 15:04:05", FullTimestamp: true} + logFormatter := &log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true} log.SetFormatter(logFormatter) if err := fflag.RegisterAllFeatures(); err != nil { diff --git a/cmd/crowdsec/hook.go b/cmd/crowdsec/hook.go new file mode 100644 index 000000000..28515d9e4 --- /dev/null +++ b/cmd/crowdsec/hook.go @@ -0,0 +1,43 @@ +package main + +import ( + "io" + "os" + + log "github.com/sirupsen/logrus" +) + +type ConditionalHook struct { + Writer io.Writer + LogLevels []log.Level + Enabled bool +} + +func (hook *ConditionalHook) Fire(entry *log.Entry) error { + if hook.Enabled { + line, err := entry.String() + if err != nil { + return err + } + + _, err = hook.Writer.Write([]byte(line)) + + return err + } + + return nil +} + +func (hook *ConditionalHook) Levels() []log.Level { + return hook.LogLevels +} + +// The primal logging hook is set up before parsing config.yaml. +// Once config.yaml is parsed, the primal hook is disabled if the +// configured logger is writing to stderr. Otherwise it's used to +// report fatal errors and panics to stderr in addition to the log file. +var primalHook = &ConditionalHook{ + Writer: os.Stderr, + LogLevels: []log.Level{log.FatalLevel, log.PanicLevel}, + Enabled: true, +} diff --git a/cmd/crowdsec/main.go b/cmd/crowdsec/main.go index 362ed1869..05a6db01d 100644 --- a/cmd/crowdsec/main.go +++ b/cmd/crowdsec/main.go @@ -80,11 +80,13 @@ func LoadBuckets(cConfig *csconfig.Config, hub *cwhub.Hub) error { err error files []string ) + for _, hubScenarioItem := range hub.GetItemMap(cwhub.SCENARIOS) { if hubScenarioItem.State.Installed { files = append(files, hubScenarioItem.State.LocalPath) } } + buckets = leakybucket.NewBuckets() log.Infof("Loading %d scenario files", len(files)) @@ -99,6 +101,7 @@ func LoadBuckets(cConfig *csconfig.Config, hub *cwhub.Hub) error { holders[holderIndex].Profiling = true } } + return nil } @@ -143,8 +146,10 @@ func (l labelsMap) Set(label string) error { if len(split) != 2 { return fmt.Errorf("invalid format for label '%s', must be key:value", pair) } + l[split[0]] = split[1] } + return nil } @@ -168,9 +173,11 @@ func (f *Flags) Parse() { flag.BoolVar(&f.DisableAPI, "no-api", false, "disable local API") flag.BoolVar(&f.DisableCAPI, "no-capi", false, "disable communication with Central API") flag.BoolVar(&f.OrderEvent, "order-event", false, "enforce event ordering with significant performance cost") + if runtime.GOOS == "windows" { flag.StringVar(&f.WinSvc, "winsvc", "", "Windows service Action: Install, Remove etc..") } + flag.StringVar(&dumpFolder, "dump-data", "", "dump parsers/buckets raw outputs") flag.Parse() } @@ -205,6 +212,7 @@ func newLogLevel(curLevelPtr *log.Level, f *Flags) *log.Level { // avoid returning a new ptr to the same value return curLevelPtr } + return &ret } @@ -238,6 +246,8 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo return nil, err } + primalHook.Enabled = (cConfig.Common.LogMedia != "stdout") + if err := csconfig.LoadFeatureFlagsFile(configFile, log.StandardLogger()); err != nil { return nil, err } @@ -282,6 +292,7 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo if cConfig.DisableAPI { cConfig.Common.Daemonize = false } + log.Infof("single file mode : log_media=%s daemonize=%t", cConfig.Common.LogMedia, cConfig.Common.Daemonize) } @@ -291,6 +302,7 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo if cConfig.Common.Daemonize && runtime.GOOS == "windows" { log.Debug("Daemonization is not supported on Windows, disabling") + cConfig.Common.Daemonize = false } @@ -308,6 +320,8 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo var crowdsecT0 time.Time func main() { + log.AddHook(primalHook) + if err := fflag.RegisterAllFeatures(); err != nil { log.Fatalf("failed to register features: %s", err) } @@ -342,5 +356,6 @@ func main() { if err != nil { log.Fatal(err) } + os.Exit(0) } diff --git a/cmd/crowdsec/output.go b/cmd/crowdsec/output.go index bfedbb979..ad53ce4c8 100644 --- a/cmd/crowdsec/output.go +++ b/cmd/crowdsec/output.go @@ -76,6 +76,15 @@ func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky return fmt.Errorf("loading list of installed hub scenarios: %w", err) } + appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES) + if err != nil { + return fmt.Errorf("loading list of installed hub appsec rules: %w", err) + } + + installedScenariosAndAppsecRules := make([]string, 0, len(scenarios)+len(appsecRules)) + installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, scenarios...) + installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, appsecRules...) + apiURL, err := url.Parse(apiConfig.URL) if err != nil { return fmt.Errorf("parsing api url ('%s'): %w", apiConfig.URL, err) @@ -87,14 +96,27 @@ func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky password := strfmt.Password(apiConfig.Password) Client, err := apiclient.NewClient(&apiclient.Config{ - MachineID: apiConfig.Login, - Password: password, - Scenarios: scenarios, - UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), - URL: apiURL, - PapiURL: papiURL, - VersionPrefix: "v1", - UpdateScenario: func() ([]string, error) {return hub.GetInstalledItemNames(cwhub.SCENARIOS)}, + MachineID: apiConfig.Login, + Password: password, + Scenarios: installedScenariosAndAppsecRules, + UserAgent: fmt.Sprintf("crowdsec/%s", version.String()), + URL: apiURL, + PapiURL: papiURL, + VersionPrefix: "v1", + UpdateScenario: func() ([]string, error) { + scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS) + if err != nil { + return nil, err + } + appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES) + if err != nil { + return nil, err + } + ret := make([]string, 0, len(scenarios)+len(appsecRules)) + ret = append(ret, scenarios...) + ret = append(ret, appsecRules...) + return ret, nil + }, }) if err != nil { return fmt.Errorf("new client api: %w", err) @@ -102,7 +124,7 @@ func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky authResp, _, err := Client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ MachineID: &apiConfig.Login, Password: &password, - Scenarios: scenarios, + Scenarios: installedScenariosAndAppsecRules, }) if err != nil { return fmt.Errorf("authenticate watcher (%s): %w", apiConfig.Login, err) diff --git a/cmd/crowdsec/run_in_svc.go b/cmd/crowdsec/run_in_svc.go index 563f758c6..8b2cb2aee 100644 --- a/cmd/crowdsec/run_in_svc.go +++ b/cmd/crowdsec/run_in_svc.go @@ -4,10 +4,8 @@ package main import ( "fmt" - "os" log "github.com/sirupsen/logrus" - "github.com/sirupsen/logrus/hooks/writer" "github.com/crowdsecurity/go-cs-lib/trace" "github.com/crowdsecurity/go-cs-lib/version" @@ -24,16 +22,6 @@ func StartRunSvc() error { defer trace.CatchPanic("crowdsec/StartRunSvc") - // Set a default logger with level=fatal on stderr, - // in addition to the one we configure afterwards - log.AddHook(&writer.Hook{ - Writer: os.Stderr, - LogLevels: []log.Level{ - log.PanicLevel, - log.FatalLevel, - }, - }) - if cConfig, err = LoadConfig(flags.ConfigFile, flags.DisableAgent, flags.DisableAPI, false); err != nil { return err } @@ -46,6 +34,7 @@ func StartRunSvc() error { // Enable profiling early if cConfig.Prometheus != nil { var dbClient *database.Client + var err error if cConfig.DbConfig != nil { @@ -55,8 +44,11 @@ func StartRunSvc() error { return fmt.Errorf("unable to create database client: %s", err) } } + registerPrometheus(cConfig.Prometheus) + go servePrometheus(cConfig.Prometheus, dbClient, apiReady, agentReady) } + return Serve(cConfig, apiReady, agentReady) } diff --git a/go.mod b/go.mod index 634f90f5e..d61c191c1 100644 --- a/go.mod +++ b/go.mod @@ -24,9 +24,9 @@ require ( github.com/buger/jsonparser v1.1.1 github.com/c-robinson/iplib v1.0.3 github.com/cespare/xxhash/v2 v2.2.0 - github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f + github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607 github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 - github.com/crowdsecurity/go-cs-lib v0.0.5 + github.com/crowdsecurity/go-cs-lib v0.0.6 github.com/crowdsecurity/grokky v0.2.1 github.com/crowdsecurity/machineid v1.0.2 github.com/davecgh/go-spew v1.1.1 @@ -201,7 +201,7 @@ require ( go.mongodb.org/mongo-driver v1.9.4 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/net v0.19.0 // indirect - golang.org/x/sync v0.5.0 // indirect + golang.org/x/sync v0.6.0 // indirect golang.org/x/term v0.15.0 // indirect golang.org/x/time v0.3.0 // indirect golang.org/x/tools v0.8.1-0.20230428195545-5283a0178901 // indirect diff --git a/go.sum b/go.sum index cb0f97b89..f5f61594e 100644 --- a/go.sum +++ b/go.sum @@ -100,10 +100,14 @@ github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f h1:FkOB9aDw0xzDd14pTarGRLsUNAymONq3dc7zhvsXElg= github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f/go.mod h1:TrU7Li+z2RHNrPy0TKJ6R65V6Yzpan2sTIRryJJyJso= +github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607 h1:hyrYw3h8clMcRL2u5ooZ3tmwnmJftmhb9Ws1MKmavvI= +github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607/go.mod h1:br36fEqurGYZQGit+iDYsIzW0FF6VufMbDzyyLxEuPA= github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:r97WNVC30Uen+7WnLs4xDScS/Ex988+id2k6mDf8psU= github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:zpv7r+7KXwgVUZnUNjyP22zc/D7LKjyoY02weH2RBbk= github.com/crowdsecurity/go-cs-lib v0.0.5 h1:eVLW+BRj3ZYn0xt5/xmgzfbbB8EBo32gM4+WpQQk2e8= github.com/crowdsecurity/go-cs-lib v0.0.5/go.mod h1:8FMKNGsh3hMZi2SEv6P15PURhEJnZV431XjzzBSuf0k= +github.com/crowdsecurity/go-cs-lib v0.0.6 h1:Ef6MylXe0GaJE9vrfvxEdbHb31+JUP1os+murPz7Pos= +github.com/crowdsecurity/go-cs-lib v0.0.6/go.mod h1:8FMKNGsh3hMZi2SEv6P15PURhEJnZV431XjzzBSuf0k= github.com/crowdsecurity/grokky v0.2.1 h1:t4VYnDlAd0RjDM2SlILalbwfCrQxtJSMGdQOR0zwkE4= github.com/crowdsecurity/grokky v0.2.1/go.mod h1:33usDIYzGDsgX1kHAThCbseso6JuWNJXOzRQDGXHtWM= github.com/crowdsecurity/machineid v1.0.2 h1:wpkpsUghJF8Khtmn/tg6GxgdhLA1Xflerh5lirI+bdc= @@ -807,6 +811,8 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/mk/goversion.mk b/mk/goversion.mk index dd9954928..73e9a72e2 100644 --- a/mk/goversion.mk +++ b/mk/goversion.mk @@ -1,5 +1,5 @@ -BUILD_GOVERSION = $(subst go,,$(shell go env GOVERSION)) +BUILD_GOVERSION = $(subst go,,$(shell $(GO) env GOVERSION)) go_major_minor = $(subst ., ,$(BUILD_GOVERSION)) GO_MAJOR_VERSION = $(word 1, $(go_major_minor)) @@ -9,7 +9,7 @@ GO_VERSION_VALIDATION_ERR_MSG = Golang version ($(BUILD_GOVERSION)) is not suppo .PHONY: goversion -goversion: $(if $(findstring devel,$(shell go env GOVERSION)),goversion_devel,goversion_check) +goversion: $(if $(findstring devel,$(shell $(GO) env GOVERSION)),goversion_devel,goversion_check) .PHONY: goversion_devel diff --git a/mk/platform/unix_common.mk b/mk/platform/unix_common.mk index 8f06c9328..5e5b5de3a 100644 --- a/mk/platform/unix_common.mk +++ b/mk/platform/unix_common.mk @@ -8,7 +8,11 @@ MKDIR=mkdir -p GOOS ?= $(shell go env GOOS) # Current versioning information from env -BUILD_VERSION?=$(shell git describe --tags) +# The $(or) is used to ignore an empty BUILD_VERSION when it's an envvar, +# like inside a docker build: docker build --build-arg BUILD_VERSION=1.2.3 +# as opposed to a make parameter: make BUILD_VERSION=1.2.3 +BUILD_VERSION:=$(or $(BUILD_VERSION),$(shell git describe --tags --dirty)) + BUILD_TIMESTAMP=$(shell date +%F"_"%T) DEFAULT_CONFIGDIR?=/etc/crowdsec DEFAULT_DATADIR?=/var/lib/crowdsec/data diff --git a/pkg/acquisition/acquisition_test.go b/pkg/acquisition/acquisition_test.go index c1373a6c7..44b3878e1 100644 --- a/pkg/acquisition/acquisition_test.go +++ b/pkg/acquisition/acquisition_test.go @@ -40,15 +40,19 @@ func (f *MockSource) Configure(cfg []byte, logger *log.Entry) error { if err := f.UnmarshalConfig(cfg); err != nil { return err } + if f.Mode == "" { f.Mode = configuration.CAT_MODE } + if f.Mode != configuration.CAT_MODE && f.Mode != configuration.TAIL_MODE { return fmt.Errorf("mode %s is not supported", f.Mode) } + if f.Toto == "" { return fmt.Errorf("expect non-empty toto") } + return nil } func (f *MockSource) GetMode() string { return f.Mode } @@ -77,6 +81,7 @@ func appendMockSource() { if GetDataSourceIface("mock") == nil { AcquisitionSources["mock"] = func() DataSource { return &MockSource{} } } + if GetDataSourceIface("mock_cant_run") == nil { AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} } } @@ -84,6 +89,7 @@ func appendMockSource() { func TestDataSourceConfigure(t *testing.T) { appendMockSource() + tests := []struct { TestName string String string @@ -185,22 +191,22 @@ wowo: ajsajasjas switch tc.TestName { case "basic_valid_config": mock := (*ds).Dump().(*MockSource) - assert.Equal(t, mock.Toto, "test_value1") - assert.Equal(t, mock.Mode, "cat") - assert.Equal(t, mock.logger.Logger.Level, log.InfoLevel) - assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) + assert.Equal(t, "test_value1", mock.Toto) + assert.Equal(t, "cat", mock.Mode) + assert.Equal(t, log.InfoLevel, mock.logger.Logger.Level) + assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels) case "basic_debug_config": mock := (*ds).Dump().(*MockSource) - assert.Equal(t, mock.Toto, "test_value1") - assert.Equal(t, mock.Mode, "cat") - assert.Equal(t, mock.logger.Logger.Level, log.DebugLevel) - assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) + assert.Equal(t, "test_value1", mock.Toto) + assert.Equal(t, "cat", mock.Mode) + assert.Equal(t, log.DebugLevel, mock.logger.Logger.Level) + assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels) case "basic_tailmode_config": mock := (*ds).Dump().(*MockSource) - assert.Equal(t, mock.Toto, "test_value1") - assert.Equal(t, mock.Mode, "tail") - assert.Equal(t, mock.logger.Logger.Level, log.DebugLevel) - assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"}) + assert.Equal(t, "test_value1", mock.Toto) + assert.Equal(t, "tail", mock.Mode) + assert.Equal(t, log.DebugLevel, mock.logger.Logger.Level) + assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels) } }) } @@ -208,6 +214,7 @@ wowo: ajsajasjas func TestLoadAcquisitionFromFile(t *testing.T) { appendMockSource() + tests := []struct { TestName string Config csconfig.CrowdsecServiceCfg @@ -284,7 +291,6 @@ func TestLoadAcquisitionFromFile(t *testing.T) { assert.Len(t, dss, tc.ExpectedLen) }) - } } @@ -304,9 +310,11 @@ func (f *MockCat) Configure(cfg []byte, logger *log.Entry) error { if f.Mode == "" { f.Mode = configuration.CAT_MODE } + if f.Mode != configuration.CAT_MODE { return fmt.Errorf("mode %s is not supported", f.Mode) } + return nil } @@ -319,6 +327,7 @@ func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) erro evt.Line.Src = "test" out <- evt } + return nil } func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error { @@ -345,9 +354,11 @@ func (f *MockTail) Configure(cfg []byte, logger *log.Entry) error { if f.Mode == "" { f.Mode = configuration.TAIL_MODE } + if f.Mode != configuration.TAIL_MODE { return fmt.Errorf("mode %s is not supported", f.Mode) } + return nil } @@ -364,6 +375,7 @@ func (f *MockTail) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) erro out <- evt } <-t.Dying() + return nil } func (f *MockTail) CanRun() error { return nil } @@ -446,6 +458,7 @@ func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) out <- evt } t.Kill(fmt.Errorf("got error (tomb)")) + return fmt.Errorf("got error") } @@ -499,6 +512,7 @@ func (f *MockSourceByDSN) ConfigureByDSN(dsn string, labels map[string]string, l if dsn != "test_expect" { return fmt.Errorf("unexpected value") } + return nil } func (f *MockSourceByDSN) GetUuid() string { return "" } diff --git a/pkg/acquisition/modules/appsec/appsec.go b/pkg/acquisition/modules/appsec/appsec.go index 49830eb85..def6a6886 100644 --- a/pkg/acquisition/modules/appsec/appsec.go +++ b/pkg/acquisition/modules/appsec/appsec.go @@ -335,7 +335,7 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) { } // parse the request only once - parsedRequest, err := appsec.NewParsedRequestFromRequest(r) + parsedRequest, err := appsec.NewParsedRequestFromRequest(r, w.logger) if err != nil { w.logger.Errorf("%s", err) rw.WriteHeader(http.StatusInternalServerError) diff --git a/pkg/acquisition/modules/docker/docker_test.go b/pkg/acquisition/modules/docker/docker_test.go index 3c3eeefe6..c4d23168a 100644 --- a/pkg/acquisition/modules/docker/docker_test.go +++ b/pkg/acquisition/modules/docker/docker_test.go @@ -57,6 +57,7 @@ container_name: subLogger := log.WithFields(log.Fields{ "type": "docker", }) + for _, test := range tests { f := DockerSource{} err := f.Configure([]byte(test.config), subLogger) @@ -66,12 +67,15 @@ container_name: func TestConfigureDSN(t *testing.T) { log.Infof("Test 'TestConfigureDSN'") + var dockerHost string + if runtime.GOOS == "windows" { dockerHost = "npipe:////./pipe/docker_engine" } else { dockerHost = "unix:///var/run/podman/podman.sock" } + tests := []struct { name string dsn string @@ -106,6 +110,7 @@ func TestConfigureDSN(t *testing.T) { subLogger := log.WithFields(log.Fields{ "type": "docker", }) + for _, test := range tests { f := DockerSource{} err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "") @@ -156,8 +161,11 @@ container_name_regexp: } for _, ts := range tests { - var logger *log.Logger - var subLogger *log.Entry + var ( + logger *log.Logger + subLogger *log.Entry + ) + if ts.expectedOutput != "" { logger.SetLevel(ts.logLevel) subLogger = logger.WithFields(log.Fields{ @@ -173,10 +181,12 @@ container_name_regexp: dockerTomb := tomb.Tomb{} out := make(chan types.Event) dockerSource := DockerSource{} + err := dockerSource.Configure([]byte(ts.config), subLogger) if err != nil { t.Fatalf("Unexpected error : %s", err) } + dockerSource.Client = new(mockDockerCli) actualLines := 0 readerTomb := &tomb.Tomb{} @@ -204,21 +214,23 @@ container_name_regexp: if err := readerTomb.Wait(); err != nil { t.Fatal(err) } + if ts.expectedLines != 0 { assert.Equal(t, ts.expectedLines, actualLines) } + err = streamTomb.Wait() if err != nil { t.Fatalf("docker acquisition error: %s", err) } } - } func (cli *mockDockerCli) ContainerList(ctx context.Context, options dockerTypes.ContainerListOptions) ([]dockerTypes.Container, error) { if readLogs == true { return []dockerTypes.Container{}, nil } + containers := make([]dockerTypes.Container, 0) container := &dockerTypes.Container{ ID: "12456", @@ -233,16 +245,20 @@ func (cli *mockDockerCli) ContainerLogs(ctx context.Context, container string, o if readLogs == true { return io.NopCloser(strings.NewReader("")), nil } + readLogs = true data := []string{"docker\n", "test\n", "1234\n"} ret := "" + for _, line := range data { startLineByte := make([]byte, 8) binary.LittleEndian.PutUint32(startLineByte, 1) //stdout stream binary.BigEndian.PutUint32(startLineByte[4:], uint32(len(line))) ret += fmt.Sprintf("%s%s", startLineByte, line) } + r := io.NopCloser(strings.NewReader(ret)) // r type is io.ReadCloser + return r, nil } @@ -252,6 +268,7 @@ func (cli *mockDockerCli) ContainerInspect(ctx context.Context, c string) (docke Tty: false, }, } + return r, nil } @@ -285,8 +302,11 @@ func TestOneShot(t *testing.T) { } for _, ts := range tests { - var subLogger *log.Entry - var logger *log.Logger + var ( + subLogger *log.Entry + logger *log.Logger + ) + if ts.expectedOutput != "" { logger.SetLevel(ts.logLevel) subLogger = logger.WithFields(log.Fields{ @@ -307,6 +327,7 @@ func TestOneShot(t *testing.T) { if err := dockerClient.ConfigureByDSN(ts.dsn, labels, subLogger, ""); err != nil { t.Fatalf("unable to configure dsn '%s': %s", ts.dsn, err) } + dockerClient.Client = new(mockDockerCli) out := make(chan types.Event, 100) tomb := tomb.Tomb{} @@ -315,8 +336,7 @@ func TestOneShot(t *testing.T) { // else we do the check before actualLines is incremented ... if ts.expectedLines != 0 { - assert.Equal(t, ts.expectedLines, len(out)) + assert.Len(t, out, ts.expectedLines) } } - } diff --git a/pkg/acquisition/modules/journalctl/journalctl_test.go b/pkg/acquisition/modules/journalctl/journalctl_test.go index 0ad49edd8..a91fba31b 100644 --- a/pkg/acquisition/modules/journalctl/journalctl_test.go +++ b/pkg/acquisition/modules/journalctl/journalctl_test.go @@ -21,6 +21,7 @@ func TestBadConfiguration(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string expectedErr string @@ -48,6 +49,7 @@ journalctl_filter: subLogger := log.WithFields(log.Fields{ "type": "journalctl", }) + for _, test := range tests { f := JournalCtlSource{} err := f.Configure([]byte(test.config), subLogger) @@ -59,6 +61,7 @@ func TestConfigureDSN(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { dsn string expectedErr string @@ -92,9 +95,11 @@ func TestConfigureDSN(t *testing.T) { expectedErr: "", }, } + subLogger := log.WithFields(log.Fields{ "type": "journalctl", }) + for _, test := range tests { f := JournalCtlSource{} err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "") @@ -106,6 +111,7 @@ func TestOneShot(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string expectedErr string @@ -137,9 +143,12 @@ journalctl_filter: }, } for _, ts := range tests { - var logger *log.Logger - var subLogger *log.Entry - var hook *test.Hook + var ( + logger *log.Logger + subLogger *log.Entry + hook *test.Hook + ) + if ts.expectedOutput != "" { logger, hook = test.NewNullLogger() logger.SetLevel(ts.logLevel) @@ -151,27 +160,32 @@ journalctl_filter: "type": "journalctl", }) } + tomb := tomb.Tomb{} out := make(chan types.Event, 100) j := JournalCtlSource{} + err := j.Configure([]byte(ts.config), subLogger) if err != nil { t.Fatalf("Unexpected error : %s", err) } + err = j.OneShotAcquisition(out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) + if err != nil { continue } if ts.expectedLines != 0 { - assert.Equal(t, ts.expectedLines, len(out)) + assert.Len(t, out, ts.expectedLines) } if ts.expectedOutput != "" { if hook.LastEntry() == nil { t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput) } + assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) hook.Reset() } @@ -182,6 +196,7 @@ func TestStreaming(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("Skipping test on windows") } + tests := []struct { config string expectedErr string @@ -202,9 +217,12 @@ journalctl_filter: }, } for _, ts := range tests { - var logger *log.Logger - var subLogger *log.Entry - var hook *test.Hook + var ( + logger *log.Logger + subLogger *log.Entry + hook *test.Hook + ) + if ts.expectedOutput != "" { logger, hook = test.NewNullLogger() logger.SetLevel(ts.logLevel) @@ -216,14 +234,18 @@ journalctl_filter: "type": "journalctl", }) } + tomb := tomb.Tomb{} out := make(chan types.Event) j := JournalCtlSource{} + err := j.Configure([]byte(ts.config), subLogger) if err != nil { t.Fatalf("Unexpected error : %s", err) } + actualLines := 0 + if ts.expectedLines != 0 { go func() { READLOOP: @@ -240,6 +262,7 @@ journalctl_filter: err = j.StreamingAcquisition(out, &tomb) cstest.AssertErrorContains(t, err, ts.expectedErr) + if err != nil { continue } @@ -248,16 +271,20 @@ journalctl_filter: time.Sleep(1 * time.Second) assert.Equal(t, ts.expectedLines, actualLines) } + tomb.Kill(nil) tomb.Wait() + output, _ := exec.Command("pgrep", "-x", "journalctl").CombinedOutput() if string(output) != "" { t.Fatalf("Found a journalctl process after killing the tomb !") } + if ts.expectedOutput != "" { if hook.LastEntry() == nil { t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput) } + assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput) hook.Reset() } @@ -270,5 +297,6 @@ func TestMain(m *testing.M) { fullPath := filepath.Join(currentDir, "test_files") os.Setenv("PATH", fullPath+":"+os.Getenv("PATH")) } + os.Exit(m.Run()) } diff --git a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go index 799868dc8..c3502c956 100644 --- a/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go +++ b/pkg/acquisition/modules/kubernetesaudit/k8s_audit_test.go @@ -9,6 +9,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/tomb.v2" ) @@ -78,24 +79,23 @@ webhook_path: /k8s-audit`, err := f.UnmarshalConfig([]byte(test.config)) - assert.NoError(t, err) + require.NoError(t, err) err = f.Configure([]byte(test.config), subLogger) - assert.NoError(t, err) + require.NoError(t, err) f.StreamingAcquisition(out, tb) time.Sleep(1 * time.Second) tb.Kill(nil) err = tb.Wait() if test.expectedErr != "" { - assert.ErrorContains(t, err, test.expectedErr) + require.ErrorContains(t, err, test.expectedErr) return } - assert.NoError(t, err) + require.NoError(t, err) }) } - } func TestHandler(t *testing.T) { @@ -252,10 +252,10 @@ webhook_path: /k8s-audit`, f := KubernetesAuditSource{} err := f.UnmarshalConfig([]byte(test.config)) - assert.NoError(t, err) + require.NoError(t, err) err = f.Configure([]byte(test.config), subLogger) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body)) w := httptest.NewRecorder() @@ -268,11 +268,11 @@ webhook_path: /k8s-audit`, assert.Equal(t, test.expectedStatusCode, res.StatusCode) //time.Sleep(1 * time.Second) - assert.NoError(t, err) + require.NoError(t, err) tb.Kill(nil) err = tb.Wait() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.eventCount, eventCount) }) diff --git a/pkg/acquisition/modules/wineventlog/wineventlog_test.go b/pkg/acquisition/modules/wineventlog/wineventlog_test.go index 20f8a5834..053ba88b5 100644 --- a/pkg/acquisition/modules/wineventlog/wineventlog_test.go +++ b/pkg/acquisition/modules/wineventlog/wineventlog_test.go @@ -11,6 +11,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/sys/windows/svc/eventlog" "gopkg.in/tomb.v2" ) @@ -124,7 +125,7 @@ event_level: bla`, } assert.Contains(t, err.Error(), test.expectedErr) } else { - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectedQuery, q) } } @@ -221,9 +222,8 @@ event_ids: } } if test.expectedLines == nil { - assert.Equal(t, 0, len(linesRead)) + assert.Empty(t, linesRead) } else { - assert.Equal(t, len(test.expectedLines), len(linesRead)) assert.Equal(t, test.expectedLines, linesRead) } to.Kill(nil) diff --git a/pkg/alertcontext/alertcontext_test.go b/pkg/alertcontext/alertcontext_test.go index 2e7e71bd6..8b598eab8 100644 --- a/pkg/alertcontext/alertcontext_test.go +++ b/pkg/alertcontext/alertcontext_test.go @@ -7,6 +7,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewAlertContext(t *testing.T) { @@ -29,8 +30,7 @@ func TestNewAlertContext(t *testing.T) { for _, test := range tests { fmt.Printf("Running test '%s'\n", test.name) err := NewAlertContext(test.contextToSend, test.valueLength) - assert.ErrorIs(t, err, test.expectedErr) - + require.ErrorIs(t, err, test.expectedErr) } } @@ -193,7 +193,7 @@ func TestEventToContext(t *testing.T) { for _, test := range tests { fmt.Printf("Running test '%s'\n", test.name) err := NewAlertContext(test.contextToSend, test.valueLength) - assert.ErrorIs(t, err, nil) + require.NoError(t, err) metas, _ := EventToContext(test.events) assert.ElementsMatch(t, test.expectedResult, metas) diff --git a/pkg/alertcontext/config.go b/pkg/alertcontext/config.go index 160804487..74ca1523a 100644 --- a/pkg/alertcontext/config.go +++ b/pkg/alertcontext/config.go @@ -2,7 +2,9 @@ package alertcontext import ( "encoding/json" + "errors" "fmt" + "io/fs" "os" "path/filepath" "slices" @@ -14,6 +16,8 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/cwhub" ) +var ErrNoContextData = errors.New("no context to send") + // this file is here to avoid circular dependencies between the configuration and the hub // HubItemWrapper is a wrapper around a hub item to unmarshal only the context part @@ -25,7 +29,7 @@ type HubItemWrapper struct { // mergeContext adds the context from src to dest. func mergeContext(dest map[string][]string, src map[string][]string) error { if len(src) == 0 { - return fmt.Errorf("no context data to merge") + return ErrNoContextData } for k, v := range src { @@ -86,8 +90,9 @@ func addContextFromFile(toSend map[string][]string, filePath string) error { } err = mergeContext(toSend, newContext) - if err != nil { - log.Warningf("while merging context from %s: %s", filePath, err) + if err != nil && !errors.Is(err, ErrNoContextData) { + // having an empty console/context.yaml is not an error + return err } return nil @@ -125,8 +130,10 @@ func LoadConsoleContext(c *csconfig.Config, hub *cwhub.Hub) error { } if err := addContextFromFile(c.Crowdsec.ContextToSend, c.Crowdsec.ConsoleContextPath); err != nil { - if !ignoreMissing || !os.IsNotExist(err) { + if !errors.Is(err, fs.ErrNotExist) { return err + } else if !ignoreMissing { + log.Warningf("while merging context from %s: %s", c.Crowdsec.ConsoleContextPath, err) } } diff --git a/pkg/apiclient/alerts_service_test.go b/pkg/apiclient/alerts_service_test.go index fcc9bd06a..31a947556 100644 --- a/pkg/apiclient/alerts_service_test.go +++ b/pkg/apiclient/alerts_service_test.go @@ -5,13 +5,14 @@ import ( "fmt" "net/http" "net/url" - "reflect" "testing" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/version" "github.com/crowdsecurity/crowdsec/pkg/models" @@ -25,12 +26,11 @@ func TestAlertsListAsMachine(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -39,19 +39,16 @@ func TestAlertsListAsMachine(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { - if r.URL.RawQuery == "ip=1.2.3.4" { testMethod(t, r, "GET") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `null`) + return } @@ -107,36 +104,26 @@ func TestAlertsListAsMachine(t *testing.T) { ]`) }) - tcapacity := int32(5) - tduration := "59m49.264032632s" - torigin := "crowdsec" tscenario := "crowdsecurity/ssh-bf" tscope := "Ip" - ttype := "ban" tvalue := "1.1.1.172" ttimestamp := "2020-11-28 10:20:46 +0000 UTC" - teventscount := int32(6) - tleakspeed := "10s" tmessage := "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761" - tscenariohash := "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f" - tscenarioversion := "0.1" - tstartat := "2020-11-28 10:20:46.842701127 +0100 +0100" - tstopat := "2020-11-28 10:20:46.845621385 +0100 +0100" expected := models.GetAlertsResponse{ &models.Alert{ - Capacity: &tcapacity, + Capacity: ptr.Of(int32(5)), CreatedAt: "2020-11-28T10:20:47+01:00", Decisions: []*models.Decision{ { - Duration: &tduration, + Duration: ptr.Of("59m49.264032632s"), ID: 1, - Origin: &torigin, + Origin: ptr.Of("crowdsec"), Scenario: &tscenario, Scope: &tscope, - Simulated: new(bool), //false, - Type: &ttype, + Simulated: ptr.Of(false), + Type: ptr.Of("ban"), Value: &tvalue, }, }, @@ -167,16 +154,16 @@ func TestAlertsListAsMachine(t *testing.T) { Timestamp: &ttimestamp, }, }, - EventsCount: &teventscount, + EventsCount: ptr.Of(int32(6)), ID: 1, - Leakspeed: &tleakspeed, + Leakspeed: ptr.Of("10s"), MachineID: "test", Message: &tmessage, Remediation: false, Scenario: &tscenario, - ScenarioHash: &tscenariohash, - ScenarioVersion: &tscenarioversion, - Simulated: new(bool), //(false), + ScenarioHash: ptr.Of("4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"), + ScenarioVersion: ptr.Of("0.1"), + Simulated: ptr.Of(false), Source: &models.Source{ AsName: "Cloudflare Inc", AsNumber: "", @@ -188,8 +175,8 @@ func TestAlertsListAsMachine(t *testing.T) { Scope: &tscope, Value: &tvalue, }, - StartAt: &tstartat, - StopAt: &tstopat, + StartAt: ptr.Of("2020-11-28 10:20:46.842701127 +0100 +0100"), + StopAt: ptr.Of("2020-11-28 10:20:46.845621385 +0100 +0100"), }, } @@ -198,30 +185,16 @@ func TestAlertsListAsMachine(t *testing.T) { //log.Debugf("expected : -> %s", spew.Sdump(expected)) //first one returns data alerts, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{}) - if err != nil { - log.Errorf("test Unable to list alerts : %+v", err) - } + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, expected, *alerts) - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if !reflect.DeepEqual(*alerts, expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } //this one doesn't - filter := AlertsListOpts{IPEquals: new(string)} - *filter.IPEquals = "1.2.3.4" + filter := AlertsListOpts{IPEquals: ptr.Of("1.2.3.4")} alerts, resp, err = client.Alerts.List(context.Background(), filter) - if err != nil { - log.Errorf("test Unable to list alerts : %+v", err) - } - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Empty(t, *alerts) } @@ -236,9 +209,7 @@ func TestAlertsGetAsMachine(t *testing.T) { log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -247,12 +218,10 @@ func TestAlertsGetAsMachine(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() + mux.HandleFunc("/alerts/2", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") w.WriteHeader(http.StatusNotFound) @@ -312,34 +281,24 @@ func TestAlertsGetAsMachine(t *testing.T) { }`) }) - tcapacity := int32(5) - tduration := "59m49.264032632s" - torigin := "crowdsec" tscenario := "crowdsecurity/ssh-bf" tscope := "Ip" ttype := "ban" tvalue := "1.1.1.172" ttimestamp := "2020-11-28 10:20:46 +0000 UTC" - teventscount := int32(6) - tleakspeed := "10s" - tmessage := "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761" - tscenariohash := "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f" - tscenarioversion := "0.1" - tstartat := "2020-11-28 10:20:46.842701127 +0100 +0100" - tstopat := "2020-11-28 10:20:46.845621385 +0100 +0100" expected := &models.Alert{ - Capacity: &tcapacity, + Capacity: ptr.Of(int32(5)), CreatedAt: "2020-11-28T10:20:47+01:00", Decisions: []*models.Decision{ { - Duration: &tduration, + Duration: ptr.Of("59m49.264032632s"), ID: 1, - Origin: &torigin, + Origin: ptr.Of("crowdsec"), Scenario: &tscenario, Scope: &tscope, - Simulated: new(bool), //false, + Simulated: ptr.Of(false), Type: &ttype, Value: &tvalue, }, @@ -371,16 +330,16 @@ func TestAlertsGetAsMachine(t *testing.T) { Timestamp: &ttimestamp, }, }, - EventsCount: &teventscount, + EventsCount: ptr.Of(int32(6)), ID: 1, - Leakspeed: &tleakspeed, + Leakspeed: ptr.Of("10s"), MachineID: "test", - Message: &tmessage, + Message: ptr.Of("Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761"), Remediation: false, Scenario: &tscenario, - ScenarioHash: &tscenariohash, - ScenarioVersion: &tscenarioversion, - Simulated: new(bool), //(false), + ScenarioHash: ptr.Of("4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"), + ScenarioVersion: ptr.Of("0.1"), + Simulated: ptr.Of(false), Source: &models.Source{ AsName: "Cloudflare Inc", AsNumber: "", @@ -392,24 +351,18 @@ func TestAlertsGetAsMachine(t *testing.T) { Scope: &tscope, Value: &tvalue, }, - StartAt: &tstartat, - StopAt: &tstopat, + StartAt: ptr.Of("2020-11-28 10:20:46.842701127 +0100 +0100"), + StopAt: ptr.Of("2020-11-28 10:20:46.845621385 +0100 +0100"), } alerts, resp, err := client.Alerts.GetByID(context.Background(), 1) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if !reflect.DeepEqual(*alerts, *expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *alerts) //fail _, _, err = client.Alerts.GetByID(context.Background(), 2) - assert.Contains(t, fmt.Sprintf("%s", err), "API error: object not found") + cstest.RequireErrorMessage(t, err, "API error: object not found") } func TestAlertsCreateAsMachine(t *testing.T) { @@ -420,17 +373,17 @@ func TestAlertsCreateAsMachine(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") w.WriteHeader(http.StatusOK) w.Write([]byte(`["3"]`)) }) + log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -439,10 +392,7 @@ func TestAlertsCreateAsMachine(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() @@ -452,13 +402,8 @@ func TestAlertsCreateAsMachine(t *testing.T) { expected := &models.AddAlertsResponse{"3"} - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if !reflect.DeepEqual(*alerts, *expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *alerts) } func TestAlertsDeleteAsMachine(t *testing.T) { @@ -469,18 +414,18 @@ func TestAlertsDeleteAsMachine(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) w.Write([]byte(`{"message":"0 deleted alerts"}`)) }) + log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - log.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -489,25 +434,16 @@ func TestAlertsDeleteAsMachine(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - log.Fatalf("new api client: %s", err) - } + require.NoError(t, err) defer teardown() - alert := AlertsDeleteOpts{IPEquals: new(string)} - *alert.IPEquals = "1.2.3.4" + alert := AlertsDeleteOpts{IPEquals: ptr.Of("1.2.3.4")} alerts, resp, err := client.Alerts.Delete(context.Background(), alert) require.NoError(t, err) expected := &models.DeleteAlertsResponse{NbDeleted: ""} - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if !reflect.DeepEqual(*alerts, *expected) { - t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *alerts) } diff --git a/pkg/apiclient/auth.go b/pkg/apiclient/auth.go index 86cdc7736..163e96718 100644 --- a/pkg/apiclient/auth.go +++ b/pkg/apiclient/auth.go @@ -3,6 +3,7 @@ package apiclient import ( "bytes" "encoding/json" + "errors" "fmt" "io" "math/rand" @@ -13,7 +14,6 @@ import ( "time" "github.com/go-openapi/strfmt" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/fflag" @@ -52,10 +52,12 @@ func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) { dump, _ := httputil.DumpRequest(req, true) log.Tracef("auth-api request: %s", string(dump)) } + // Make the HTTP request. resp, err := t.transport().RoundTrip(req) if err != nil { log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err) + return resp, err } @@ -115,10 +117,12 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) for i := 0; i < maxAttempts; i++ { if i > 0 { if r.withBackOff { + //nolint:gosec backoff += 10 + rand.Intn(20) } log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts) + select { case <-req.Context().Done(): return resp, req.Context().Err() @@ -134,8 +138,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) resp, err = r.next.RoundTrip(clonedReq) if err != nil { - left := maxAttempts - i - 1 - if left > 0 { + if left := maxAttempts - i - 1; left > 0 { log.Errorf("error while performing request: %s; %d retries left", err, left) } @@ -177,7 +180,7 @@ func (t *JWTTransport) refreshJwtToken() error { log.Debugf("scenarios list updated for '%s'", *t.MachineID) } - var auth = models.WatcherAuthRequest{ + auth := models.WatcherAuthRequest{ MachineID: t.MachineID, Password: t.Password, Scenarios: t.Scenarios, @@ -264,13 +267,14 @@ func (t *JWTTransport) refreshJwtToken() error { // RoundTrip implements the RoundTripper interface. func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // in a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI + // In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI // we use a mutex to avoid this - //We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request) + // We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request) t.refreshTokenMutex.Lock() if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) { if err := t.refreshJwtToken(); err != nil { t.refreshTokenMutex.Unlock() + return nil, err } } @@ -296,8 +300,9 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { } if err != nil { - /*we had an error (network error for example, or 401 because token is refused), reset the token ?*/ + // we had an error (network error for example, or 401 because token is refused), reset the token ? t.Token = "" + return resp, fmt.Errorf("performing jwt auth: %w", err) } @@ -355,6 +360,7 @@ func cloneRequest(r *http.Request) *http.Request { *r2 = *r // deep copy of the Header r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { r2.Header[k] = append([]string(nil), s...) } diff --git a/pkg/apiclient/auth_service_test.go b/pkg/apiclient/auth_service_test.go index b56d52868..f5de827a1 100644 --- a/pkg/apiclient/auth_service_test.go +++ b/pkg/apiclient/auth_service_test.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/version" @@ -24,13 +25,11 @@ type BasicMockPayload struct { } func getLoginsForMockErrorCases() map[string]int { - loginsForMockErrorCases := map[string]int{ + return map[string]int{ "login_400": http.StatusBadRequest, "login_409": http.StatusConflict, "login_500": http.StatusInternalServerError, } - - return loginsForMockErrorCases } func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { @@ -49,7 +48,7 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { w.WriteHeader(http.StatusBadRequest) } - responseBody := "" + var responseBody string responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID] if !hasFoundErrorMock { @@ -58,6 +57,7 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) { } else { responseBody = fmt.Sprintf("Error %d", responseCode) } + log.Printf("MockServerReceived > %s // Login : [%s] => Mux response [%d]", newStr, payload.MachineID, responseCode) w.WriteHeader(responseCode) @@ -76,14 +76,13 @@ func TestWatcherRegister(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + //body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password} initBasicMuxMock(t, mux, "/watchers") log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) // Valid Registration : should retrieve the client and no err clientconfig := Config{ @@ -95,9 +94,7 @@ func TestWatcherRegister(t *testing.T) { } client, err := RegisterClient(&clientconfig, &http.Client{}) - if client == nil || err != nil { - t.Fatalf("while registering client : %s", err) - } + require.NoError(t, err) log.Printf("->%T", client) @@ -107,11 +104,8 @@ func TestWatcherRegister(t *testing.T) { clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) client, err = RegisterClient(&clientconfig, &http.Client{}) - if client != nil || err == nil { - t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest) - } else { - log.Printf("The RegisterClient function handled the error code %d as expected \n\r", errorCodeToTest) - } + require.Nil(t, client, "nil expected for the response code %d", errorCodeToTest) + require.Error(t, err, "error expected for the response code %d", errorCodeToTest) } } @@ -126,9 +120,7 @@ func TestWatcherAuth(t *testing.T) { log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok auth clientConfig := &Config{ @@ -139,34 +131,27 @@ func TestWatcherAuth(t *testing.T) { VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, } - client, err := NewClient(clientConfig) - if err != nil { - t.Fatalf("new api client: %s", err) - } + client, err := NewClient(clientConfig) + require.NoError(t, err) _, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ MachineID: &clientConfig.MachineID, Password: &clientConfig.Password, Scenarios: clientConfig.Scenarios, }) - if err != nil { - t.Fatalf("unexpect auth err 0: %s", err) - } + require.NoError(t, err) // Testing error handling on AuthenticateWatcher (400, 409): should retrieve an error // Not testing 500 because it loops and try to re-autehnticate. But you can test it manually by adding it in array errorCodesToTest := [2]int{http.StatusBadRequest, http.StatusConflict} for _, errorCodeToTest := range errorCodesToTest { clientConfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest) + client, err := NewClient(clientConfig) + require.NoError(t, err) - if err != nil { - t.Fatalf("new api client: %s", err) - } - - var resp *Response - _, resp, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + _, resp, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ MachineID: &clientConfig.MachineID, Password: &clientConfig.Password, }) @@ -175,9 +160,7 @@ func TestWatcherAuth(t *testing.T) { resp.Response.Body.Close() bodyBytes, err := io.ReadAll(resp.Response.Body) - if err != nil { - t.Fatalf("error while reading body: %s", err.Error()) - } + require.NoError(t, err) log.Printf(string(bodyBytes)) t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest) @@ -199,10 +182,12 @@ func TestWatcherUnregister(t *testing.T) { assert.Equal(t, int64(0), r.ContentLength) w.WriteHeader(http.StatusOK) }) + mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") buf := new(bytes.Buffer) _, _ = buf.ReadFrom(r.Body) + newStr := buf.String() if newStr == `{"machine_id":"test_login","password":"test_password","scenarios":["crowdsecurity/test"]} ` { @@ -217,9 +202,7 @@ func TestWatcherUnregister(t *testing.T) { log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) mycfg := &Config{ MachineID: "test_login", @@ -229,16 +212,12 @@ func TestWatcherUnregister(t *testing.T) { VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, } - client, err := NewClient(mycfg) - if err != nil { - t.Fatalf("new api client: %s", err) - } + client, err := NewClient(mycfg) + require.NoError(t, err) _, err = client.Auth.UnregisterWatcher(context.Background()) - if err != nil { - t.Fatalf("while registering client : %s", err) - } + require.NoError(t, err) log.Printf("->%T", client) } @@ -255,6 +234,7 @@ func TestWatcherEnroll(t *testing.T) { _, _ = buf.ReadFrom(r.Body) newStr := buf.String() log.Debugf("body -> %s", newStr) + if newStr == `{"attachment_key":"goodkey","name":"","tags":[],"overwrite":false} ` { log.Print("good key") @@ -266,17 +246,17 @@ func TestWatcherEnroll(t *testing.T) { fmt.Fprintf(w, `{"message":"the attachment key provided is not valid"}`) } }) + mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "POST") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`) }) + log.Printf("URL is %s", urlx) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) mycfg := &Config{ MachineID: "test_login", @@ -286,16 +266,12 @@ func TestWatcherEnroll(t *testing.T) { VersionPrefix: "v1", Scenarios: []string{"crowdsecurity/test"}, } - client, err := NewClient(mycfg) - if err != nil { - t.Fatalf("new api client: %s", err) - } + client, err := NewClient(mycfg) + require.NoError(t, err) _, err = client.Auth.EnrollWatcher(context.Background(), "goodkey", "", []string{}, false) - if err != nil { - t.Fatalf("unexpect enroll err: %s", err) - } + require.NoError(t, err) _, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false) assert.Contains(t, err.Error(), "the attachment key provided is not valid", "got %s", err.Error()) diff --git a/pkg/apiclient/auth_test.go b/pkg/apiclient/auth_test.go index 7e7377a43..f686de622 100644 --- a/pkg/apiclient/auth_test.go +++ b/pkg/apiclient/auth_test.go @@ -9,6 +9,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" ) func TestApiAuth(t *testing.T) { @@ -17,6 +20,7 @@ func TestApiAuth(t *testing.T) { mux, urlx, teardown := setup() mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "GET") + if r.Header.Get("X-Api-Key") == "ixu" { assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) w.WriteHeader(http.StatusOK) @@ -26,11 +30,11 @@ func TestApiAuth(t *testing.T) { w.Write([]byte(`{"message":"access forbidden"}`)) } }) + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) defer teardown() @@ -40,18 +44,12 @@ func TestApiAuth(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } - - alert := DecisionsListOpts{IPEquals: new(string)} - *alert.IPEquals = "1.2.3.4" - _, resp, err := newcli.Decisions.List(context.Background(), alert) require.NoError(t, err) - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } + alert := DecisionsListOpts{IPEquals: ptr.Of("1.2.3.4")} + _, resp, err := newcli.Decisions.List(context.Background(), alert) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) //ko bad token auth = &APIKeyTransport{ @@ -59,25 +57,21 @@ func TestApiAuth(t *testing.T) { } newcli, err = NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) _, resp, err = newcli.Decisions.List(context.Background(), alert) log.Infof("--> %s", err) - if resp.Response.StatusCode != http.StatusForbidden { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } + assert.Equal(t, http.StatusForbidden, resp.Response.StatusCode) + + cstest.RequireErrorMessage(t, err, "API error: access forbidden") - assert.Contains(t, err.Error(), "API error: access forbidden") //ko empty token auth = &APIKeyTransport{} + newcli, err = NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) _, _, err = newcli.Decisions.List(context.Background(), alert) require.Error(t, err) diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index 75bc52881..b183a8c79 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -74,6 +74,7 @@ func NewClient(config *Config) (*ApiClient, error) { VersionPrefix: config.VersionPrefix, UpdateScenario: config.UpdateScenario, } + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig.RootCAs = CaCertPool @@ -180,8 +181,7 @@ func (e *ErrorResponse) Error() string { } func newResponse(r *http.Response) *Response { - response := &Response{Response: r} - return response + return &Response{Response: r} } func CheckResponse(r *http.Response) error { @@ -192,7 +192,7 @@ func CheckResponse(r *http.Response) error { errorResponse := &ErrorResponse{} data, err := io.ReadAll(r.Body) - if err == nil && data != nil { + if err == nil && len(data)>0 { err := json.Unmarshal(data, errorResponse) if err != nil { return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err) diff --git a/pkg/apiclient/client_http_test.go b/pkg/apiclient/client_http_test.go index fa25ee171..a7582eaf4 100644 --- a/pkg/apiclient/client_http_test.go +++ b/pkg/apiclient/client_http_test.go @@ -8,19 +8,19 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/version" ) func TestNewRequestInvalid(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + //missing slash in uri apiURL, err := url.Parse(urlx) - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -29,9 +29,8 @@ func TestNewRequestInvalid(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) @@ -44,17 +43,16 @@ func TestNewRequestInvalid(t *testing.T) { }) _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) - assert.Contains(t, err.Error(), `building request: BaseURL must have a trailing slash, but `) + cstest.RequireErrorContains(t, err, "building request: BaseURL must have a trailing slash, but ") } func TestNewRequestTimeout(t *testing.T) { mux, urlx, teardown := setup() defer teardown() - //missing slash in uri + + // missing slash in uri apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) client, err := NewClient(&Config{ MachineID: "test_login", @@ -63,9 +61,8 @@ func TestNewRequestTimeout(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { time.Sleep(2 * time.Second) @@ -75,5 +72,5 @@ func TestNewRequestTimeout(t *testing.T) { defer cancel() _, _, err = client.Alerts.List(ctx, AlertsListOpts{}) - assert.Contains(t, err.Error(), `performing request: context deadline exceeded`) + cstest.RequireErrorMessage(t, err, "performing request: context deadline exceeded") } diff --git a/pkg/apiclient/client_test.go b/pkg/apiclient/client_test.go index a75b3dd41..dc6eae169 100644 --- a/pkg/apiclient/client_test.go +++ b/pkg/apiclient/client_test.go @@ -11,7 +11,9 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/version" ) @@ -20,13 +22,13 @@ import ( - each test will then bind handler for the method(s) they want to try */ -func setup() (mux *http.ServeMux, serverURL string, teardown func()) { +func setup() (*http.ServeMux, string, func()) { return setupWithPrefix("v1") } -func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) { +func setupWithPrefix(urlPrefix string) (*http.ServeMux, string, func()) { // mux is the HTTP request multiplexer used with the test server. - mux = http.NewServeMux() + mux := http.NewServeMux() baseURLPath := "/" + urlPrefix apiHandler := http.NewServeMux() @@ -40,19 +42,16 @@ func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, te func testMethod(t *testing.T, r *http.Request, want string) { t.Helper() - - if got := r.Method; got != want { - t.Errorf("Request method: %v, want %v", got, want) - } + assert.Equal(t, want, r.Method) } func TestNewClientOk(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -60,9 +59,8 @@ func TestNewClientOk(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -75,22 +73,17 @@ func TestNewClientOk(t *testing.T) { }) _, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{}) - if err != nil { - t.Fatalf("test Unable to list alerts : %+v", err) - } - - if resp.Response.StatusCode != http.StatusOK { - t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated) - } + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) } func TestNewClientKo(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -98,9 +91,8 @@ func TestNewClientKo(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) + /*mock login*/ mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) @@ -113,36 +105,36 @@ func TestNewClientKo(t *testing.T) { }) _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) - assert.Contains(t, err.Error(), `API error: bad login/password`) + cstest.RequireErrorContains(t, err, `API error: bad login/password`) + log.Printf("err-> %s", err) } func TestNewDefaultClient(t *testing.T) { mux, urlx, teardown := setup() defer teardown() + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewDefaultClient(apiURL, "/v1", "", nil) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`{"code": 401, "message" : "brr"}`)) }) + _, _, err = client.Alerts.List(context.Background(), AlertsListOpts{}) - assert.Contains(t, err.Error(), `performing request: API error: brr`) + cstest.RequireErrorMessage(t, err, "performing request: API error: brr") + log.Printf("err-> %s", err) } func TestNewClientRegisterKO(t *testing.T) { apiURL, err := url.Parse("http://127.0.0.1:4242/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + _, err = RegisterClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -150,17 +142,18 @@ func TestNewClientRegisterKO(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) + if runtime.GOOS != "windows" { - assert.Contains(t, fmt.Sprintf("%s", err), "dial tcp 127.0.0.1:4242: connect: connection refused") + cstest.RequireErrorContains(t, err, "dial tcp 127.0.0.1:4242: connect: connection refused") } else { - assert.Contains(t, fmt.Sprintf("%s", err), " No connection could be made because the target machine actively refused it.") + cstest.RequireErrorContains(t, err, " No connection could be made because the target machine actively refused it.") } } func TestNewClientRegisterOK(t *testing.T) { log.SetLevel(log.TraceLevel) - mux, urlx, teardown := setup() + mux, urlx, teardown := setup() defer teardown() /*mock login*/ @@ -171,9 +164,8 @@ func TestNewClientRegisterOK(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := RegisterClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -181,17 +173,15 @@ func TestNewClientRegisterOK(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) - if err != nil { - t.Fatalf("while registering client : %s", err) - } + require.NoError(t, err) log.Printf("->%T", client) } func TestNewClientBadAnswer(t *testing.T) { log.SetLevel(log.TraceLevel) - mux, urlx, teardown := setup() + mux, urlx, teardown := setup() defer teardown() /*mock login*/ @@ -200,10 +190,10 @@ func TestNewClientBadAnswer(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte(`bad`)) }) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + _, err = RegisterClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -211,5 +201,5 @@ func TestNewClientBadAnswer(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }, &http.Client{}) - assert.Contains(t, fmt.Sprintf("%s", err), `invalid body: invalid character 'b' looking for beginning of value`) + cstest.RequireErrorContains(t, err, "invalid body: invalid character 'b' looking for beginning of value") } diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index 89e6eff92..a3f02c0ef 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -183,7 +183,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl req = req.WithContext(ctx) log.Debugf("[URL] %s %s", req.Method, req.URL) - // we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc + // we don't use client_http Do method because we need the reader and is not provided. + // We would be forced to use Pipe and goroutine, etc resp, err := client.Do(req) if resp != nil && resp.Body != nil { defer resp.Body.Close() @@ -216,6 +217,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl if resp.StatusCode != http.StatusOK { log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL) + return nil, false, nil } diff --git a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go index e9954d9a1..fb2fb7342 100644 --- a/pkg/apiclient/decisions_service_test.go +++ b/pkg/apiclient/decisions_service_test.go @@ -5,13 +5,13 @@ import ( "fmt" "net/http" "net/url" - "reflect" "testing" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/version" @@ -38,10 +38,9 @@ func TestDecisionsList(t *testing.T) { //no results } }) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok answer auth := &APIKeyTransport{ @@ -49,55 +48,32 @@ func TestDecisionsList(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" - torigin := "cscli" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" - tscope := "Ip" - ttype := "ban" - tvalue := "1.2.3.4" expected := &models.GetDecisionsResponse{ &models.Decision{ - Duration: &tduration, + Duration: ptr.Of("3h59m55.756182786s"), ID: 4, - Origin: &torigin, - Scenario: &tscenario, - Scope: &tscope, - Type: &ttype, - Value: &tvalue, + Origin: ptr.Of("cscli"), + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), + Scope: ptr.Of("Ip"), + Type: ptr.Of("ban"), + Value: ptr.Of("1.2.3.4"), }, } - //OK decisions - decisionsFilter := DecisionsListOpts{IPEquals: new(string)} - *decisionsFilter.IPEquals = "1.2.3.4" + // OK decisions + decisionsFilter := DecisionsListOpts{IPEquals: ptr.Of("1.2.3.4")} decisions, resp, err := newcli.Decisions.List(context.Background(), decisionsFilter) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) //Empty return - decisionsFilter = DecisionsListOpts{IPEquals: new(string)} - *decisionsFilter.IPEquals = "1.2.3.5" + decisionsFilter = DecisionsListOpts{IPEquals: ptr.Of("1.2.3.5")} decisions, resp, err = newcli.Decisions.List(context.Background(), decisionsFilter) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Empty(t, *decisions) } @@ -120,6 +96,7 @@ func TestDecisionsStream(t *testing.T) { } } }) + mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "ixu", r.Header.Get("X-Api-Key")) testMethod(t, r, http.MethodDelete) @@ -129,9 +106,7 @@ func TestDecisionsStream(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok answer auth := &APIKeyTransport{ @@ -139,63 +114,38 @@ func TestDecisionsStream(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" - torigin := "cscli" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" - tscope := "Ip" - ttype := "ban" - tvalue := "1.2.3.4" expected := &models.DecisionsStreamResponse{ New: models.GetDecisionsResponse{ &models.Decision{ - Duration: &tduration, + Duration: ptr.Of("3h59m55.756182786s"), ID: 4, - Origin: &torigin, - Scenario: &tscenario, - Scope: &tscope, - Type: &ttype, - Value: &tvalue, + Origin: ptr.Of("cscli"), + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), + Scope: ptr.Of("Ip"), + Type: ptr.Of("ban"), + Value: ptr.Of("1.2.3.4"), }, }, } decisions, resp, err := newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: true}) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) //and second call, we get empty lists decisions, resp, err = newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: false}) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) assert.Empty(t, decisions.New) assert.Empty(t, decisions.Deleted) //delete stream resp, err = newcli.Decisions.StopStream(context.Background()) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) } func TestDecisionsStreamV3Compatibility(t *testing.T) { @@ -219,9 +169,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok answer auth := &APIKeyTransport{ @@ -229,38 +177,30 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" torigin := "CAPI" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" tscope := "ip" ttype := "ban" - tvalue := "1.2.3.4" - tvalue1 := "1.2.3.5" - tscenarioDeleted := "deleted" - tdurationDeleted := "1h" expected := &models.DecisionsStreamResponse{ New: models.GetDecisionsResponse{ &models.Decision{ - Duration: &tduration, + Duration: ptr.Of("3h59m55.756182786s"), Origin: &torigin, - Scenario: &tscenario, + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), Scope: &tscope, Type: &ttype, - Value: &tvalue, + Value: ptr.Of("1.2.3.4"), }, }, Deleted: models.GetDecisionsResponse{ &models.Decision{ - Duration: &tdurationDeleted, + Duration: ptr.Of("1h"), Origin: &torigin, - Scenario: &tscenarioDeleted, + Scenario: ptr.Of("deleted"), Scope: &tscope, Type: &ttype, - Value: &tvalue1, + Value: ptr.Of("1.2.3.5"), }, }, } @@ -268,18 +208,8 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) { // GetStream is supposed to consume v3 payload and return v2 response decisions, resp, err := newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: true}) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) } func TestDecisionsStreamV3(t *testing.T) { @@ -300,9 +230,7 @@ func TestDecisionsStreamV3(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok answer auth := &APIKeyTransport{ @@ -310,30 +238,19 @@ func TestDecisionsStreamV3(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tduration := "3h59m55.756182786s" - tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'" tscope := "ip" - tvalue := "1.2.3.4" - tvalue1 := "1.2.3.5" - tdurationBlocklist := "24h" - tnameBlocklist := "blocklist1" - tremediationBlocklist := "ban" - tscopeBlocklist := "ip" - turlBlocklist := "/v3/blocklist" expected := &modelscapi.GetDecisionsStreamResponse{ New: modelscapi.GetDecisionsStreamResponseNew{ &modelscapi.GetDecisionsStreamResponseNewItem{ Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{ { - Duration: &tduration, - Value: &tvalue, + Duration: ptr.Of("3h59m55.756182786s"), + Value: ptr.Of("1.2.3.4"), }, }, - Scenario: &tscenario, + Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"), Scope: &tscope, }, }, @@ -341,18 +258,18 @@ func TestDecisionsStreamV3(t *testing.T) { &modelscapi.GetDecisionsStreamResponseDeletedItem{ Scope: &tscope, Decisions: []string{ - tvalue1, + "1.2.3.5", }, }, }, Links: &modelscapi.GetDecisionsStreamResponseLinks{ Blocklists: []*modelscapi.BlocklistLink{ { - Duration: &tdurationBlocklist, - Name: &tnameBlocklist, - Remediation: &tremediationBlocklist, - Scope: &tscopeBlocklist, - URL: &turlBlocklist, + Duration: ptr.Of("24h"), + Name: ptr.Of("blocklist1"), + Remediation: ptr.Of("ban"), + Scope: ptr.Of("ip"), + URL: ptr.Of("/v3/blocklist"), }, }, }, @@ -361,18 +278,8 @@ func TestDecisionsStreamV3(t *testing.T) { // GetStream is supposed to consume v3 payload and return v2 response decisions, resp, err := newcli.Decisions.GetStreamV3(context.Background(), DecisionsStreamOpts{Startup: true}) require.NoError(t, err) - - if resp.Response.StatusCode != http.StatusOK { - t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK) - } - - if err != nil { - t.Fatalf("new api client: %s", err) - } - - if !reflect.DeepEqual(*decisions, *expected) { - t.Fatalf("returned %+v, want %+v", resp, expected) - } + assert.Equal(t, http.StatusOK, resp.Response.StatusCode) + assert.Equal(t, *expected, *decisions) } func TestDecisionsFromBlocklist(t *testing.T) { @@ -383,10 +290,13 @@ func TestDecisionsFromBlocklist(t *testing.T) { mux.HandleFunc("/blocklist", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, http.MethodGet) + if r.Header.Get("If-Modified-Since") == "Sun, 01 Jan 2023 01:01:01 GMT" { w.WriteHeader(http.StatusNotModified) + return } + if r.Method == http.MethodGet { w.WriteHeader(http.StatusOK) w.Write([]byte("1.2.3.4\r\n1.2.3.5")) @@ -394,9 +304,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { }) apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) //ok answer auth := &APIKeyTransport{ @@ -404,12 +312,8 @@ func TestDecisionsFromBlocklist(t *testing.T) { } newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client()) - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) - tvalue1 := "1.2.3.4" - tvalue2 := "1.2.3.5" tdurationBlocklist := "24h" tnameBlocklist := "blocklist1" tremediationBlocklist := "ban" @@ -419,7 +323,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { expected := []*models.Decision{ { Duration: &tdurationBlocklist, - Value: &tvalue1, + Value: ptr.Of("1.2.3.4"), Scenario: &tnameBlocklist, Scope: &tscopeBlocklist, Type: &tremediationBlocklist, @@ -427,7 +331,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { }, { Duration: &tdurationBlocklist, - Value: &tvalue2, + Value: ptr.Of("1.2.3.5"), Scenario: &tnameBlocklist, Scope: &tscopeBlocklist, Type: &tremediationBlocklist, @@ -450,13 +354,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { log.Infof("expected : %s, %s, %s, %s, %s", *expected[0].Value, *expected[0].Duration, *expected[0].Scenario, *expected[0].Scope, *expected[0].Type) log.Infof("decisions: %s, %s, %s, %s, %s", *decisions[1].Value, *decisions[1].Duration, *decisions[1].Scenario, *decisions[1].Scope, *decisions[1].Type) - if err != nil { - t.Fatalf("new api client: %s", err) - } - - if !reflect.DeepEqual(decisions, expected) { - t.Fatalf("returned %+v, want %+v", decisions, expected) - } + assert.Equal(t, expected, decisions) // test cache control _, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{ @@ -466,8 +364,10 @@ func TestDecisionsFromBlocklist(t *testing.T) { Name: &tnameBlocklist, Duration: &tdurationBlocklist, }, ptr.Of("Sun, 01 Jan 2023 01:01:01 GMT")) + require.NoError(t, err) assert.False(t, isModified) + _, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{ URL: &turlBlocklist, Scope: &tscopeBlocklist, @@ -475,6 +375,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { Name: &tnameBlocklist, Duration: &tdurationBlocklist, }, ptr.Of("Mon, 02 Jan 2023 01:01:01 GMT")) + require.NoError(t, err) assert.True(t, isModified) } @@ -485,6 +386,7 @@ func TestDeleteDecisions(t *testing.T) { w.WriteHeader(http.StatusOK) w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`)) }) + mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) { testMethod(t, r, "DELETE") assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery) @@ -492,11 +394,12 @@ func TestDeleteDecisions(t *testing.T) { w.Write([]byte(`{"nbDeleted":"1"}`)) //w.Write([]byte(`{"message":"0 deleted alerts"}`)) }) + log.Printf("URL is %s", urlx) + apiURL, err := url.Parse(urlx + "/") - if err != nil { - t.Fatalf("parsing api url: %s", apiURL) - } + require.NoError(t, err) + client, err := NewClient(&Config{ MachineID: "test_login", Password: "test_password", @@ -504,18 +407,13 @@ func TestDeleteDecisions(t *testing.T) { URL: apiURL, VersionPrefix: "v1", }) - - if err != nil { - t.Fatalf("new api client: %s", err) - } + require.NoError(t, err) filters := DecisionsDeleteOpts{IPEquals: new(string)} *filters.IPEquals = "1.2.3.4" - deleted, _, err := client.Decisions.Delete(context.Background(), filters) - if err != nil { - t.Fatalf("unexpected err : %s", err) - } + deleted, _, err := client.Decisions.Delete(context.Background(), filters) + require.NoError(t, err) assert.Equal(t, "1", deleted.NbDeleted) defer teardown() @@ -530,22 +428,23 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { ScenariosContaining string ScenariosNotContaining string } + tests := []struct { - name string - fields fields - want string - wantErr bool + name string + fields fields + expected string + expectedErr string }{ { - name: "no filter", - want: baseURLString + "?", + name: "no filter", + expected: baseURLString + "?", }, { name: "startup=true", fields: fields{ Startup: true, }, - want: baseURLString + "?startup=true", + expected: baseURLString + "?startup=true", }, { name: "set all params", @@ -555,7 +454,7 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { ScenariosContaining: "ssh", ScenariosNotContaining: "bf", }, - want: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true", + expected: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true", }, } @@ -568,25 +467,20 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) { ScenariosContaining: tt.fields.ScenariosContaining, ScenariosNotContaining: tt.fields.ScenariosNotContaining, } + got, err := o.addQueryParamsToURL(baseURLString) - if (err != nil) != tt.wantErr { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() error = %v, wantErr %v", err, tt.wantErr) + cstest.RequireErrorContains(t, err, tt.expectedErr) + if tt.expectedErr != "" { return } gotURL, err := url.Parse(got) - if err != nil { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() got error while parsing URL: %s", err) - } + require.NoError(t, err) - expectedURL, err := url.Parse(tt.want) - if err != nil { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() got error while parsing URL: %s", err) - } + expectedURL, err := url.Parse(tt.expected) + require.NoError(t, err) - if *gotURL != *expectedURL { - t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() = %v, want %v", *gotURL, *expectedURL) - } + assert.Equal(t, *expectedURL, *gotURL) }) } } diff --git a/pkg/apiclient/heartbeat.go b/pkg/apiclient/heartbeat.go index bf61b8d2e..77e0ecc2e 100644 --- a/pkg/apiclient/heartbeat.go +++ b/pkg/apiclient/heartbeat.go @@ -38,11 +38,13 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) { select { case <-hbTimer.C: log.Debug("heartbeat: sending heartbeat") + ok, resp, err := h.Ping(ctx) if err != nil { log.Errorf("heartbeat error : %s", err) continue } + resp.Response.Body.Close() if resp.Response.StatusCode != http.StatusOK { log.Errorf("heartbeat unexpected return code : %d", resp.Response.StatusCode) diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 5824eb060..536505817 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -10,8 +10,8 @@ import ( "testing" "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" @@ -22,21 +22,14 @@ type LAPI struct { router *gin.Engine loginResp models.WatcherAuthResponse bouncerKey string - t *testing.T DBConfig *csconfig.DatabaseCfg } func SetupLAPITest(t *testing.T) LAPI { t.Helper() - router, loginResp, config, err := InitMachineTest(t) - if err != nil { - t.Fatal(err) - } + router, loginResp, config := InitMachineTest(t) - APIKey, err := CreateTestBouncer(config.API.Server.DbConfig) - if err != nil { - t.Fatal(err) - } + APIKey := CreateTestBouncer(t, config.API.Server.DbConfig) return LAPI{ router: router, @@ -46,24 +39,23 @@ func SetupLAPITest(t *testing.T) LAPI { } } -func (l *LAPI) InsertAlertFromFile(path string) *httptest.ResponseRecorder { - alertReader := GetAlertReaderFromFile(path) - return l.RecordResponse(http.MethodPost, "/v1/alerts", alertReader, "password") +func (l *LAPI) InsertAlertFromFile(t *testing.T, path string) *httptest.ResponseRecorder { + alertReader := GetAlertReaderFromFile(t, path) + return l.RecordResponse(t, http.MethodPost, "/v1/alerts", alertReader, "password") } -func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { +func (l *LAPI) RecordResponse(t *testing.T, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { w := httptest.NewRecorder() req, err := http.NewRequest(verb, url, body) - if err != nil { - l.t.Fatal(err) - } + require.NoError(t, err) - if authType == "apikey" { + switch authType { + case "apikey": req.Header.Add("X-Api-Key", l.bouncerKey) - } else if authType == "password" { + case "password": AddAuthHeaders(req, l.loginResp) - } else { - l.t.Fatal("auth type not supported") + default: + t.Fatal("auth type not supported") } l.router.ServeHTTP(w, req) @@ -71,29 +63,16 @@ func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, aut return w } -func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config, error) { - router, config, err := NewAPITest(t) - if err != nil { - return nil, models.WatcherAuthResponse{}, config, fmt.Errorf("unable to run local API: %s", err) - } +func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) { + router, config := NewAPITest(t) + loginResp := LoginToTestAPI(t, router, config) - loginResp, err := LoginToTestAPI(router, config) - if err != nil { - return nil, models.WatcherAuthResponse{}, config, err - } - - return router, loginResp, config, nil + return router, loginResp, config } -func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherAuthResponse, error) { - body, err := CreateTestMachine(router) - if err != nil { - return models.WatcherAuthResponse{}, err - } - err = ValidateMachine("test", config.API.Server.DbConfig) - if err != nil { - log.Fatalln(err) - } +func LoginToTestAPI(t *testing.T, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse { + body := CreateTestMachine(t, router) + ValidateMachine(t, "test", config.API.Server.DbConfig) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body)) @@ -101,12 +80,10 @@ func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherA router.ServeHTTP(w, req) loginResp := models.WatcherAuthResponse{} - err = json.NewDecoder(w.Body).Decode(&loginResp) - if err != nil { - return models.WatcherAuthResponse{}, err - } + err := json.NewDecoder(w.Body).Decode(&loginResp) + require.NoError(t, err) - return loginResp, nil + return loginResp } func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthResponse) { @@ -116,17 +93,17 @@ func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthRespon func TestSimulatedAlert(t *testing.T) { lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_minibulk+simul.json") - alertContent := GetAlertReaderFromFile("./tests/alert_minibulk+simul.json") + lapi.InsertAlertFromFile(t, "./tests/alert_minibulk+simul.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_minibulk+simul.json") //exclude decision in simulation mode - w := lapi.RecordResponse("GET", "/v1/alerts?simulated=false", alertContent, "password") + w := lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) //include decision in simulation mode - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", alertContent, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) @@ -136,35 +113,29 @@ func TestCreateAlert(t *testing.T) { lapi := SetupLAPITest(t) // Create Alert with invalid format - w := lapi.RecordResponse(http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") + w := lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password") assert.Equal(t, 400, w.Code) - assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) + assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create Alert with invalid input - alertContent := GetAlertReaderFromFile("./tests/invalidAlert_sample.json") + alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json") - w = lapi.RecordResponse(http.MethodPost, "/v1/alerts", alertContent, "password") + w = lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", alertContent, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"validation failure list:\\n0.scenario in body is required\\n0.scenario_hash in body is required\\n0.scenario_version in body is required\\n0.simulated in body is required\\n0.source in body is required\"}", w.Body.String()) + assert.Equal(t, `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, w.Body.String()) // Create Valid Alert - w = lapi.InsertAlertFromFile("./tests/alert_sample.json") + w = lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assert.Equal(t, 201, w.Code) - assert.Equal(t, "[\"1\"]", w.Body.String()) + assert.Equal(t, `["1"]`, w.Body.String()) } func TestCreateAlertChannels(t *testing.T) { - apiServer, config, err := NewAPIServer(t) - if err != nil { - log.Fatalln(err) - } + apiServer, config := NewAPIServer(t) apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert) apiServer.InitController() - loginResp, err := LoginToTestAPI(apiServer.router, config) - if err != nil { - log.Fatalln(err) - } + loginResp := LoginToTestAPI(t, apiServer.router, config) lapi := LAPI{router: apiServer.router, loginResp: loginResp} var ( @@ -180,7 +151,7 @@ func TestCreateAlertChannels(t *testing.T) { wg.Done() }() - go lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json") + go lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json") wg.Wait() assert.Len(t, pd.Alert.Decisions, 1) apiServer.Close() @@ -188,18 +159,18 @@ func TestCreateAlertChannels(t *testing.T) { func TestAlertListFilters(t *testing.T) { lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json") - alertContent := GetAlertReaderFromFile("./tests/alert_ssh-bf.json") + lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_ssh-bf.json") //bad filter - w := lapi.RecordResponse("GET", "/v1/alerts?test=test", alertContent, "password") + w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", alertContent, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String()) + assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) //get without filters - w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) //check alert and decision assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") @@ -207,149 +178,149 @@ func TestAlertListFilters(t *testing.T) { //test decision_type filter (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ban", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test decision_type filter (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test scope (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?scope=Ip", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=Ip", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test scope (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?scope=rarara", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=rarara", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test scenario (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test scenario (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test ip (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test ip (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test ip (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=gruueq", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String()) //test range (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test range - w = lapi.RecordResponse("GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test range (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?range=ratata", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=ratata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String()) //test since (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1h", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1h", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test since (ok but yields no results) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1ns", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test since (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1zuzu", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) //test until (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1ns", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test until (ok but no return) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1m", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1m", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test until (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1zuzu", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) //test simulated (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test simulated (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=false", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test has active decision - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test has active decision - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test has active decision (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) } @@ -357,32 +328,32 @@ func TestAlertListFilters(t *testing.T) { func TestAlertBulkInsert(t *testing.T) { lapi := SetupLAPITest(t) //insert a bulk of 20 alerts to trigger bulk insert - lapi.InsertAlertFromFile("./tests/alert_bulk.json") - alertContent := GetAlertReaderFromFile("./tests/alert_bulk.json") + lapi.InsertAlertFromFile(t, "./tests/alert_bulk.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_bulk.json") - w := lapi.RecordResponse("GET", "/v1/alerts", alertContent, "password") + w := lapi.RecordResponse(t, "GET", "/v1/alerts", alertContent, "password") assert.Equal(t, 200, w.Code) } func TestListAlert(t *testing.T) { lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // List Alert with invalid filter - w := lapi.RecordResponse("GET", "/v1/alerts?test=test", emptyBody, "password") + w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", emptyBody, "password") assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String()) + assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String()) // List Alert - w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password") + w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "crowdsecurity/test") } func TestCreateAlertErrors(t *testing.T) { lapi := SetupLAPITest(t) - alertContent := GetAlertReaderFromFile("./tests/alert_sample.json") + alertContent := GetAlertReaderFromFile(t, "./tests/alert_sample.json") //test invalid bearer w := httptest.NewRecorder() @@ -403,7 +374,7 @@ func TestCreateAlertErrors(t *testing.T) { func TestDeleteAlert(t *testing.T) { lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // Fail Delete Alert w := httptest.NewRecorder() @@ -426,7 +397,7 @@ func TestDeleteAlert(t *testing.T) { func TestDeleteAlertByID(t *testing.T) { lapi := SetupLAPITest(t) - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // Fail Delete Alert w := httptest.NewRecorder() @@ -454,25 +425,18 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"} cfg.API.Server.ListenURI = "::8080" server, err := NewServer(cfg.API.Server) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + err = server.InitController() - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + router, err := server.Router() - if err != nil { - log.Fatal(err) - } - loginResp, err := LoginToTestAPI(router, cfg) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) + + loginResp := LoginToTestAPI(t, router, cfg) lapi := LAPI{ router: router, loginResp: loginResp, - t: t, } assertAlertDeleteFailedFromIP := func(ip string) { @@ -498,17 +462,17 @@ func TestDeleteAlertTrustedIPS(t *testing.T) { assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assertAlertDeleteFailedFromIP("4.3.2.1") assertAlertDeletedFromIP("1.2.3.4") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.0") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.1") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assertAlertDeletedFromIP("1.2.4.255") - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") assertAlertDeletedFromIP("127.0.0.1") } diff --git a/pkg/apiserver/api_key_test.go b/pkg/apiserver/api_key_test.go index df61e0b26..883ff2129 100644 --- a/pkg/apiserver/api_key_test.go +++ b/pkg/apiserver/api_key_test.go @@ -6,20 +6,14 @@ import ( "strings" "testing" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) func TestAPIKey(t *testing.T) { - router, config, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, config := NewAPITest(t) + + APIKey := CreateTestBouncer(t, config.API.Server.DbConfig) - APIKey, err := CreateTestBouncer(config.API.Server.DbConfig) - if err != nil { - log.Fatal(err) - } // Login with empty token w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader("")) @@ -27,7 +21,7 @@ func TestAPIKey(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) - assert.Equal(t, "{\"message\":\"access forbidden\"}", w.Body.String()) + assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with invalid token w = httptest.NewRecorder() @@ -37,7 +31,7 @@ func TestAPIKey(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) - assert.Equal(t, "{\"message\":\"access forbidden\"}", w.Body.String()) + assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String()) // Login with valid token w = httptest.NewRecorder() diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index 16dba1e86..74c627cd0 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -36,6 +36,7 @@ import ( func getDBClient(t *testing.T) *database.Client { t.Helper() + dbPath, err := os.CreateTemp("", "*sqlite") require.NoError(t, err) dbClient, err := database.NewClient(&csconfig.DatabaseCfg{ @@ -72,8 +73,9 @@ func getAPIC(t *testing.T) *apic { } } -func absDiff(a int, b int) (c int) { - if c = a - b; c < 0 { +func absDiff(a int, b int) int { + c := a - b + if c < 0 { return -1 * c } @@ -185,6 +187,7 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) { func TestNewAPIC(t *testing.T) { var testConfig *csconfig.OnlineApiClientCfg + setConfig := func() { testConfig = &csconfig.OnlineApiClientCfg{ Credentials: &csconfig.ApiCredentialsCfg{ @@ -199,6 +202,7 @@ func TestNewAPIC(t *testing.T) { dbClient *database.Client consoleConfig *csconfig.ConsoleConfig } + tests := []struct { name string args args @@ -374,7 +378,6 @@ func TestAPICGetMetrics(t *testing.T) { assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers) assert.Equal(t, tc.expectedMetric.Machines, foundMetrics.Machines) - }) } } @@ -403,6 +406,7 @@ func TestCreateAlertsForDecision(t *testing.T) { type args struct { decisions []*models.Decision } + tests := []struct { name string args args @@ -489,6 +493,7 @@ func TestFillAlertsWithDecisions(t *testing.T) { alerts []*models.Alert decisions []*models.Decision } + tests := []struct { name string args args @@ -544,26 +549,18 @@ func TestAPICWhitelists(t *testing.T) { api := getAPIC(t) //one whitelist on IP, one on CIDR api.whitelists = &csconfig.CapiWhitelist{} - ipwl1 := "9.2.3.4" - ip := net.ParseIP(ipwl1) - api.whitelists.Ips = append(api.whitelists.Ips, ip) - ipwl1 = "7.2.3.4" - ip = net.ParseIP(ipwl1) - api.whitelists.Ips = append(api.whitelists.Ips, ip) - cidrwl1 := "13.2.3.0/24" - _, tnet, err := net.ParseCIDR(cidrwl1) - if err != nil { - t.Fatalf("unable to parse cidr : %s", err) - } + api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4")) + + _, tnet, err := net.ParseCIDR("13.2.3.0/24") + require.NoError(t, err) api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet) - cidrwl1 = "11.2.3.0/24" - _, tnet, err = net.ParseCIDR(cidrwl1) - if err != nil { - t.Fatalf("unable to parse cidr : %s", err) - } + + _, tnet, err = net.ParseCIDR("11.2.3.0/24") + require.NoError(t, err) api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet) + api.dbClient.Ent.Decision.Create(). SetOrigin(types.CAPIOrigin). SetType("ban"). @@ -663,12 +660,15 @@ func TestAPICWhitelists(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", httpmock.NewStringResponder( 200, "1.2.3.6", )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist2", httpmock.NewStringResponder( 200, "1.2.3.7", )) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -801,12 +801,15 @@ func TestAPICPullTop(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", httpmock.NewStringResponder( 200, "1.2.3.6", )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist2", httpmock.NewStringResponder( 200, "1.2.3.7", )) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -828,7 +831,8 @@ func TestAPICPullTop(t *testing.T) { alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background()) validDecisions := api.dbClient.Ent.Decision.Query().Where( decision.UntilGT(time.Now())). - AllX(context.Background()) + AllX(context.Background(), + ) decisionScenarioFreq := make(map[string]int) alertScenario := make(map[string]int) @@ -858,6 +862,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { httpmock.Activate() defer httpmock.DeactivateAndReset() + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder( 200, jsonMarshalX( modelscapi.GetDecisionsStreamResponse{ @@ -887,10 +892,12 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { assert.Equal(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(200, "1.2.3.4"), nil }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -916,6 +923,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { assert.NotEqual(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(304, ""), nil }) + err = api.PullTop(false) require.NoError(t, err) secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName) @@ -928,6 +936,7 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { httpmock.Activate() defer httpmock.DeactivateAndReset() + // create a decision about to expire. It should force fetch alertInstance := api.dbClient.Ent.Alert. Create(). @@ -975,10 +984,12 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) { }, ), )) + httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { assert.Equal(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(304, ""), nil }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -1005,6 +1016,7 @@ func TestAPICPullBlocklistCall(t *testing.T) { assert.Equal(t, "", req.Header.Get("If-Modified-Since")) return httpmock.NewStringResponse(200, "1.2.3.4"), nil }) + url, err := url.ParseRequestURI("http://api.crowdsec.net/") require.NoError(t, err) @@ -1073,6 +1085,7 @@ func TestAPICPush(t *testing.T) { Source: &models.Source{}, } } + return alerts }(), }, diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 11d0c3eaa..638ac2c65 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -307,6 +307,7 @@ func (s *APIServer) Run(apiReady chan bool) error { log.Errorf("capi push: %s", err) return err } + return nil }) @@ -315,6 +316,7 @@ func (s *APIServer) Run(apiReady chan bool) error { log.Errorf("capi pull: %s", err) return err } + return nil }) @@ -328,6 +330,7 @@ func (s *APIServer) Run(apiReady chan bool) error { log.Errorf("papi pull: %s", err) return err } + return nil }) @@ -336,6 +339,7 @@ func (s *APIServer) Run(apiReady chan bool) error { log.Errorf("capi decisions sync: %s", err) return err } + return nil }) } else { diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index 62a8b83dd..b7f6be5fe 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -13,11 +13,12 @@ import ( "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/cstest" + "github.com/crowdsecurity/go-cs-lib/ptr" "github.com/crowdsecurity/go-cs-lib/version" middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" @@ -63,13 +64,14 @@ func LoadTestConfig(t *testing.T) csconfig.Config { ShareCustomScenarios: new(bool), }, } + apiConfig := csconfig.APICfg{ Server: &apiServerConfig, } + config.API = &apiConfig - if err := config.API.Server.LoadProfiles(); err != nil { - log.Fatalf("failed to load profiles: %s", err) - } + err := config.API.Server.LoadProfiles() + require.NoError(t, err) return config } @@ -106,110 +108,89 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config { Server: &apiServerConfig, } config.API = &apiConfig - if err := config.API.Server.LoadProfiles(); err != nil { - log.Fatalf("failed to load profiles: %s", err) - } + err := config.API.Server.LoadProfiles() + require.NoError(t, err) return config } -func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) { +func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) { config := LoadTestConfig(t) os.Remove("./ent") + apiServer, err := NewServer(config.API.Server) - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) log.Printf("Creating new API server") gin.SetMode(gin.TestMode) - return apiServer, config, nil + return apiServer, config } -func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config, error) { - apiServer, config, err := NewAPIServer(t) - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } - err = apiServer.InitController() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } +func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) { + apiServer, config := NewAPIServer(t) + + err := apiServer.InitController() + require.NoError(t, err) + router, err := apiServer.Router() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) - return router, config, nil + return router, config } -func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config, error) { +func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) { config := LoadTestConfigForwardedFor(t) os.Remove("./ent") + apiServer, err := NewServer(config.API.Server) - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) + err = apiServer.InitController() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) log.Printf("Creating new API server") gin.SetMode(gin.TestMode) + router, err := apiServer.Router() - if err != nil { - return nil, config, fmt.Errorf("unable to run local API: %s", err) - } + require.NoError(t, err) - return router, config, nil + return router, config } -func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error { +func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) { dbClient, err := database.NewClient(config) - if err != nil { - return fmt.Errorf("unable to create new database client: %s", err) - } + require.NoError(t, err) - if err := dbClient.ValidateMachine(machineID); err != nil { - return fmt.Errorf("unable to validate machine: %s", err) - } - - return nil + err = dbClient.ValidateMachine(machineID) + require.NoError(t, err) } -func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error) { +func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string { dbClient, err := database.NewClient(config) - if err != nil { - return "", fmt.Errorf("unable to create new database client: %s", err) - } + require.NoError(t, err) + machines, err := dbClient.ListMachines() - if err != nil { - return "", fmt.Errorf("Unable to list machines: %s", err) - } + require.NoError(t, err) for _, machine := range machines { if machine.MachineId == machineID { - return machine.IpAddress, nil + return machine.IpAddress } } - return "", nil + return "" } -func GetAlertReaderFromFile(path string) *strings.Reader { +func GetAlertReaderFromFile(t *testing.T, path string) *strings.Reader { alertContentBytes, err := os.ReadFile(path) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) alerts := make([]*models.Alert, 0) - if err = json.Unmarshal(alertContentBytes, &alerts); err != nil { - log.Fatal(err) - } + err = json.Unmarshal(alertContentBytes, &alerts) + require.NoError(t, err) for _, alert := range alerts { *alert.StartAt = time.Now().UTC().Format(time.RFC3339) @@ -217,74 +198,57 @@ func GetAlertReaderFromFile(path string) *strings.Reader { } alertContent, err := json.Marshal(alerts) - if err != nil { - log.Fatal(err) - } + require.NoError(t, err) return strings.NewReader(string(alertContent)) } -func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) { +func readDecisionsGetResp(t *testing.T, resp *httptest.ResponseRecorder) ([]*models.Decision, int) { var response []*models.Decision - if resp == nil { - return nil, 0, errors.New("response is nil") - } - err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } + require.NotNil(t, resp) - return response, resp.Code, nil + err := json.Unmarshal(resp.Body.Bytes(), &response) + require.NoError(t, err) + + return response, resp.Code } -func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) { +func readDecisionsErrorResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string]string, int) { var response map[string]string - if resp == nil { - return nil, 0, errors.New("response is nil") - } - err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } + require.NotNil(t, resp) - return response, resp.Code, nil + err := json.Unmarshal(resp.Body.Bytes(), &response) + require.NoError(t, err) + + return response, resp.Code } -func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) { +func readDecisionsDeleteResp(t *testing.T, resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int) { var response models.DeleteDecisionResponse - if resp == nil { - return nil, 0, errors.New("response is nil") - } + require.NotNil(t, resp) err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } + require.NoError(t, err) - return &response, resp.Code, nil + return &response, resp.Code } -func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) { +func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int) { response := make(map[string][]*models.Decision) - if resp == nil { - return nil, 0, errors.New("response is nil") - } + require.NotNil(t, resp) err := json.Unmarshal(resp.Body.Bytes(), &response) - if err != nil { - return nil, resp.Code, err - } + require.NoError(t, err) - return response, resp.Code, nil + return response, resp.Code } -func CreateTestMachine(router *gin.Engine) (string, error) { +func CreateTestMachine(t *testing.T, router *gin.Engine) string { b, err := json.Marshal(MachineTest) - if err != nil { - return "", fmt.Errorf("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() @@ -292,26 +256,20 @@ func CreateTestMachine(router *gin.Engine) (string, error) { req.Header.Set("User-Agent", UserAgent) router.ServeHTTP(w, req) - return body, nil + return body } -func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) { +func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string { dbClient, err := database.NewClient(config) - if err != nil { - log.Fatalf("unable to create new database client: %s", err) - } + require.NoError(t, err) apiKey, err := middlewares.GenerateAPIKey(keyLength) - if err != nil { - return "", fmt.Errorf("unable to generate api key: %s", err) - } + require.NoError(t, err) _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) - if err != nil { - return "", fmt.Errorf("unable to create blocker: %s", err) - } + require.NoError(t, err) - return apiKey, nil + return apiKey } func TestWithWrongDBConfig(t *testing.T) { @@ -334,10 +292,7 @@ func TestWithWrongFlushConfig(t *testing.T) { } func TestUnknownPath(t *testing.T) { - router, _, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, _ := NewAPITest(t) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/test", nil) @@ -384,24 +339,17 @@ func TestLoggingDebugToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - lvl := log.DebugLevel expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) expectedLines := []string{"/test42"} - cfg.LogLevel = &lvl + cfg.LogLevel = ptr.Of(log.DebugLevel) // Configure logging - if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil { - t.Fatal(err) - } + err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) + require.NoError(t, err) api, err := NewServer(&cfg) - if err != nil { - t.Fatalf("failed to create api : %s", err) - } - - if api == nil { - t.Fatalf("failed to create api #2 is nbill") - } + require.NoError(t, err) + require.NotNil(t, api) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/test42", nil) @@ -413,14 +361,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) { //check file content data, err := os.ReadFile(expectedFile) - if err != nil { - t.Fatalf("failed to read file : %s", err) - } + require.NoError(t, err) for _, expectedStr := range expectedLines { - if !strings.Contains(string(data), expectedStr) { - t.Fatalf("expected %s in %s", expectedStr, string(data)) - } + assert.Contains(t, string(data), expectedStr) } } @@ -446,35 +390,29 @@ func TestLoggingErrorToFileConfig(t *testing.T) { LogDir: tempDir, DbConfig: &dbconfig, } - lvl := log.ErrorLevel expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir) - cfg.LogLevel = &lvl + cfg.LogLevel = ptr.Of(log.ErrorLevel) // Configure logging - if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil { - t.Fatal(err) - } - api, err := NewServer(&cfg) - if err != nil { - t.Fatalf("failed to create api : %s", err) - } + err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false) + require.NoError(t, err) - if api == nil { - t.Fatalf("failed to create api #2 is nbill") - } + api, err := NewServer(&cfg) + require.NoError(t, err) + require.NotNil(t, api) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodGet, "/test42", nil) req.Header.Set("User-Agent", UserAgent) api.router.ServeHTTP(w, req) - assert.Equal(t, 404, w.Code) + assert.Equal(t, http.StatusNotFound, w.Code) //wait for the request to happen time.Sleep(500 * time.Millisecond) //check file content x, err := os.ReadFile(expectedFile) - if err == nil && len(x) > 0 { - t.Fatalf("file should be empty, got '%s'", x) + if err == nil { + require.Empty(t, x) } os.Remove("./crowdsec.log") diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index 5794b40d3..bab196512 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -55,6 +55,7 @@ func serveHealth() http.HandlerFunc { // no caching required health.WithDisabledCache(), ) + return health.NewHandler(checker) } @@ -76,6 +77,7 @@ func (c *Controller) NewV1() error { if err != nil { return err } + c.Router.GET("/health", gin.WrapF(serveHealth())) c.Router.Use(v1.PrometheusMiddleware()) c.Router.HandleMethodNotAllowed = true @@ -104,7 +106,6 @@ func (c *Controller) NewV1() error { jwtAuth.DELETE("/decisions", c.HandlerV1.DeleteDecisions) jwtAuth.DELETE("/decisions/:decision_id", c.HandlerV1.DeleteDecisionById) jwtAuth.GET("/heartbeat", c.HandlerV1.HeartBeat) - } apiKeyAuth := groupV1.Group("") diff --git a/pkg/apiserver/controllers/v1/alerts.go b/pkg/apiserver/controllers/v1/alerts.go index 10841ce45..424c20af6 100644 --- a/pkg/apiserver/controllers/v1/alerts.go +++ b/pkg/apiserver/controllers/v1/alerts.go @@ -22,7 +22,6 @@ import ( ) func FormatOneAlert(alert *ent.Alert) *models.Alert { - var outputAlert models.Alert startAt := alert.StartedAt.String() StopAt := alert.StoppedAt.String() @@ -31,7 +30,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { machineID = alert.Edges.Owner.MachineId } - outputAlert = models.Alert{ + outputAlert := models.Alert{ ID: int64(alert.ID), MachineID: machineID, CreatedAt: alert.CreatedAt.Format(time.RFC3339), @@ -58,23 +57,27 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { Longitude: alert.SourceLongitude, }, } + for _, eventItem := range alert.Edges.Events { var Metas models.Meta timestamp := eventItem.Time.String() if err := json.Unmarshal([]byte(eventItem.Serialized), &Metas); err != nil { log.Errorf("unable to unmarshall events meta '%s' : %s", eventItem.Serialized, err) } + outputAlert.Events = append(outputAlert.Events, &models.Event{ Timestamp: ×tamp, Meta: Metas, }) } + for _, metaItem := range alert.Edges.Metas { outputAlert.Meta = append(outputAlert.Meta, &models.MetaItems0{ Key: metaItem.Key, Value: metaItem.Value, }) } + for _, decisionItem := range alert.Edges.Decisions { duration := decisionItem.Until.Sub(time.Now().UTC()).String() outputAlert.Decisions = append(outputAlert.Decisions, &models.Decision{ @@ -88,6 +91,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert { ID: int64(decisionItem.ID), }) } + return &outputAlert } @@ -97,6 +101,7 @@ func FormatAlerts(result []*ent.Alert) models.AddAlertsRequest { for _, alertItem := range result { data = append(data, FormatOneAlert(alertItem)) } + return data } @@ -107,6 +112,7 @@ func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uin select { case c.PluginChannel <- csplugin.ProfileAlert{ProfileID: profileID, Alert: alert}: log.Debugf("alert sent to Plugin channel") + break RETRY default: log.Warningf("Cannot send alert to Plugin channel (try: %d)", try) @@ -133,7 +139,6 @@ func normalizeScope(scope string) string { // CreateAlert writes the alerts received in the body to the database func (c *Controller) CreateAlert(gctx *gin.Context) { - var input models.AddAlertsRequest claims := jwt.ExtractClaims(gctx) @@ -144,13 +149,16 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()}) return } + if err := input.Validate(strfmt.Default); err != nil { c.HandleDBErrors(gctx, err) return } + stopFlush := false + for _, alert := range input { - //normalize scope for alert.Source and decisions + // normalize scope for alert.Source and decisions if alert.Source.Scope != nil { *alert.Source.Scope = normalizeScope(*alert.Source.Scope) } @@ -161,15 +169,16 @@ func (c *Controller) CreateAlert(gctx *gin.Context) { } alert.MachineID = machineID - //generate uuid here for alert + // generate uuid here for alert alert.UUID = uuid.NewString() - //if coming from cscli, alert already has decisions + // if coming from cscli, alert already has decisions if len(alert.Decisions) != 0 { //alert already has a decision (cscli decisions add etc.), generate uuid here for _, decision := range alert.Decisions { decision.UUID = uuid.NewString() } + for pIdx, profile := range c.Profiles { _, matched, err := profile.EvaluateProfile(alert) if err != nil { diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index 465accbac..e4c9dda47 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const ( @@ -16,23 +15,22 @@ func TestDeleteDecisionRange(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse("DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by range - w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) // delete by range : ensure it was already deleted - w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) } @@ -41,23 +39,23 @@ func TestDeleteDecisionFilter(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse("DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by ip good - w = lapi.RecordResponse("DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) // delete by scope/value - w = lapi.RecordResponse("DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } @@ -66,17 +64,17 @@ func TestDeleteDecisionFilterByScenario(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") // delete by wrong scenario - w := lapi.RecordResponse("DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by scenario good - w = lapi.RecordResponse("DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) } @@ -85,14 +83,13 @@ func TestGetDecisionFilters(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_minibulk.json") + lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json") // Get Decision - w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY) + w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err := readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code := readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 2) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) @@ -104,10 +101,9 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : type filter - w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, APIKEY) + w = lapi.RecordResponse(t, "GET", "/v1/decisions?type=ban", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 2) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) @@ -122,10 +118,9 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : scope/value - w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) + w = lapi.RecordResponse(t, "GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 1) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) @@ -137,10 +132,9 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : ip filter - w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) + w = lapi.RecordResponse(t, "GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 1) assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario) @@ -151,10 +145,9 @@ func TestGetDecisionFilters(t *testing.T) { // assert.NotContains(t, w.Body.String(), `"id":2,"origin":"crowdsec","scenario":"crowdsecurity/ssh-bf","scope":"Ip","type":"ban","value":"91.121.79.178"`) // Get decision : by range - w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) + w = lapi.RecordResponse(t, "GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err = readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code = readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 2) assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.179") @@ -165,13 +158,12 @@ func TestGetDecision(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // Get Decision - w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY) + w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) - decisions, code, err := readDecisionsGetResp(w) - require.NoError(t, err) + decisions, code := readDecisionsGetResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions, 3) /*decisions get doesn't perform deduplication*/ @@ -188,7 +180,7 @@ func TestGetDecision(t *testing.T) { assert.Equal(t, int64(3), decisions[2].ID) // Get Decision with invalid filter. It should ignore this filter - w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, APIKEY) + w = lapi.RecordResponse(t, "GET", "/v1/decisions?test=test", emptyBody, APIKEY) assert.Equal(t, 200, w.Code) assert.Len(t, decisions, 3) } @@ -197,49 +189,43 @@ func TestDeleteDecisionByID(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") //Have one alerts - w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err := readDecisionsStreamResp(w) - require.NoError(t, err) + w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) // Delete alert with Invalid ID - w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/test", emptyBody, PASSWORD) assert.Equal(t, 400, w.Code) - errResp, _, err := readDecisionsErrorResp(w) - require.NoError(t, err) + errResp, _ := readDecisionsErrorResp(t, w) assert.Equal(t, "decision_id must be valid integer", errResp["message"]) // Delete alert with ID that not exist - w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/100", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) - errResp, _, err = readDecisionsErrorResp(w) - require.NoError(t, err) + errResp, _ = readDecisionsErrorResp(t, w) assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"]) //Have one alerts - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - require.NoError(t, err) + w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) // Delete alert with valid ID - w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - resp, _, err := readDecisionsDeleteResp(w) - require.NoError(t, err) + resp, _ := readDecisionsDeleteResp(t, w) assert.Equal(t, "1", resp.NbDeleted) //Have one alert (because we delete an alert that has dup targets) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - require.NoError(t, err) + w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) @@ -249,20 +235,18 @@ func TestDeleteDecision(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // Delete alert with Invalid filter - w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) + w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD) assert.Equal(t, 500, w.Code) - errResp, _, err := readDecisionsErrorResp(w) - require.NoError(t, err) + errResp, _ := readDecisionsErrorResp(t, w) assert.Equal(t, "'test' doesn't exist: invalid filter", errResp["message"]) // Delete all alert - w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) - resp, _, err := readDecisionsDeleteResp(w) - require.NoError(t, err) + resp, _ := readDecisionsDeleteResp(t, w) assert.Equal(t, "3", resp.NbDeleted) } @@ -271,12 +255,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { lapi := SetupLAPITest(t) // Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3 - lapi.InsertAlertFromFile("./tests/alert_sample.json") + lapi.InsertAlertFromFile(t, "./tests/alert_sample.json") // Get Stream, we only get one decision (the longest one) - w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err := readDecisionsStreamResp(w) - require.NoError(t, err) + w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code := readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) @@ -285,13 +268,12 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip - w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/3", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) // Get Stream, we only get one decision (the longest one, id=2) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - require.NoError(t, err) + w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) @@ -300,13 +282,12 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // We delete another decision, yet don't receive it in stream, since there's another decision on same IP - w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/2", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) // And get the remaining decision (1) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - require.NoError(t, err) + w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Empty(t, decisions["deleted"]) assert.Len(t, decisions["new"], 1) @@ -315,13 +296,12 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value) // We delete the last decision, we receive the delete order - w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD) + w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD) assert.Equal(t, 200, w.Code) //and now we only get a deleted decision - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) - decisions, code, err = readDecisionsStreamResp(w) - require.NoError(t, err) + w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY) + decisions, code = readDecisionsStreamResp(t, w) assert.Equal(t, 200, code) assert.Len(t, decisions["deleted"], 1) assert.Equal(t, int64(1), decisions["deleted"][0].ID) diff --git a/pkg/apiserver/heartbeat_test.go b/pkg/apiserver/heartbeat_test.go index 0082f23ec..fbf01c7fb 100644 --- a/pkg/apiserver/heartbeat_test.go +++ b/pkg/apiserver/heartbeat_test.go @@ -10,9 +10,9 @@ import ( func TestHeartBeat(t *testing.T) { lapi := SetupLAPITest(t) - w := lapi.RecordResponse(http.MethodGet, "/v1/heartbeat", emptyBody, "password") + w := lapi.RecordResponse(t, http.MethodGet, "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 200, w.Code) - w = lapi.RecordResponse("POST", "/v1/heartbeat", emptyBody, "password") + w = lapi.RecordResponse(t, "POST", "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 405, w.Code) } diff --git a/pkg/apiserver/jwt_test.go b/pkg/apiserver/jwt_test.go index 886962250..58f66cfc7 100644 --- a/pkg/apiserver/jwt_test.go +++ b/pkg/apiserver/jwt_test.go @@ -6,20 +6,13 @@ import ( "strings" "testing" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" ) func TestLogin(t *testing.T) { - router, config, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, config := NewAPITest(t) - body, err := CreateTestMachine(router) - if err != nil { - log.Fatalln(err) - } + body := CreateTestMachine(t, router) // Login with machine not validated yet w := httptest.NewRecorder() @@ -28,16 +21,16 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"machine test not validated\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"machine test not validated"}`, w.Body.String()) // Login with machine not exist w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test1\", \"password\": \"test1\"}")) + req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"ent: machine not found\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"ent: machine not found"}`, w.Body.String()) // Login with invalid body w = httptest.NewRecorder() @@ -46,31 +39,28 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"missing: invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"missing: invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Login with invalid format w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test1\"}")) + req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"validation failure list:\\npassword in body is required\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String()) //Validate machine - err = ValidateMachine("test", config.API.Server.DbConfig) - if err != nil { - log.Fatalln(err) - } + ValidateMachine(t, "test", config.API.Server.DbConfig) // Login with invalid password w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test\", \"password\": \"test1\"}")) + req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 401, w.Code) - assert.Equal(t, "{\"code\":401,\"message\":\"incorrect Username or Password\"}", w.Body.String()) + assert.Equal(t, `{"code":401,"message":"incorrect Username or Password"}`, w.Body.String()) // Login with valid machine w = httptest.NewRecorder() @@ -79,16 +69,16 @@ func TestLogin(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "\"token\"") - assert.Contains(t, w.Body.String(), "\"expire\"") + assert.Contains(t, w.Body.String(), `"token"`) + assert.Contains(t, w.Body.String(), `"expire"`) // Login with valid machine + scenarios w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test\", \"password\": \"test\", \"scenarios\": [\"crowdsecurity/test\", \"crowdsecurity/test2\"]}")) + req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "\"token\"") - assert.Contains(t, w.Body.String(), "\"expire\"") + assert.Contains(t, w.Body.String(), `"token"`) + assert.Contains(t, w.Body.String(), `"expire"`) } diff --git a/pkg/apiserver/machines_test.go b/pkg/apiserver/machines_test.go index 6ac016404..08efa91c6 100644 --- a/pkg/apiserver/machines_test.go +++ b/pkg/apiserver/machines_test.go @@ -7,15 +7,12 @@ import ( "strings" "testing" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCreateMachine(t *testing.T) { - router, _, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, _ := NewAPITest(t) // Create machine with invalid format w := httptest.NewRecorder() @@ -24,22 +21,21 @@ func TestCreateMachine(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 400, w.Code) - assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) + assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String()) // Create machine with invalid input w = httptest.NewRecorder() - req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader("{\"test\": \"test\"}")) + req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`)) req.Header.Add("User-Agent", UserAgent) router.ServeHTTP(w, req) assert.Equal(t, 500, w.Code) - assert.Equal(t, "{\"message\":\"validation failure list:\\nmachine_id in body is required\\npassword in body is required\"}", w.Body.String()) + assert.Equal(t, `{"message":"validation failure list:\nmachine_id in body is required\npassword in body is required"}`, w.Body.String()) // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w = httptest.NewRecorder() @@ -52,16 +48,12 @@ func TestCreateMachine(t *testing.T) { } func TestCreateMachineWithForwardedFor(t *testing.T) { - router, config, err := NewAPITestForwardedFor(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, config := NewAPITestForwardedFor(t) router.TrustedPlatform = "X-Real-IP" // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() @@ -73,25 +65,18 @@ func TestCreateMachineWithForwardedFor(t *testing.T) { assert.Equal(t, 201, w.Code) assert.Equal(t, "", w.Body.String()) - ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig) - if err != nil { - log.Fatalf("Could not get machine IP : %s", err) - } + ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) assert.Equal(t, "1.1.1.1", ip) } func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { - router, config, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, config := NewAPITest(t) // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() @@ -103,26 +88,20 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) { assert.Equal(t, 201, w.Code) assert.Equal(t, "", w.Body.String()) - ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig) - if err != nil { - log.Fatalf("Could not get machine IP : %s", err) - } + ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) + //For some reason, the IP is empty when running tests //if no forwarded-for headers are present assert.Equal(t, "", ip) } func TestCreateMachineWithoutForwardedFor(t *testing.T) { - router, config, err := NewAPITestForwardedFor(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, config := NewAPITestForwardedFor(t) // Create machine b, err := json.Marshal(MachineTest) - if err != nil { - log.Fatal("unable to marshal MachineTest") - } + require.NoError(t, err) + body := string(b) w := httptest.NewRecorder() @@ -133,25 +112,17 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) { assert.Equal(t, 201, w.Code) assert.Equal(t, "", w.Body.String()) - ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig) - if err != nil { - log.Fatalf("Could not get machine IP : %s", err) - } + ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig) + //For some reason, the IP is empty when running tests //if no forwarded-for headers are present assert.Equal(t, "", ip) } func TestCreateMachineAlreadyExist(t *testing.T) { - router, _, err := NewAPITest(t) - if err != nil { - log.Fatalf("unable to run local API: %s", err) - } + router, _ := NewAPITest(t) - body, err := CreateTestMachine(router) - if err != nil { - log.Fatalln(err) - } + body := CreateTestMachine(t, router) w := httptest.NewRecorder() req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body)) @@ -164,5 +135,5 @@ func TestCreateMachineAlreadyExist(t *testing.T) { router.ServeHTTP(w, req) assert.Equal(t, 403, w.Code) - assert.Equal(t, "{\"message\":\"user 'test': user already exist\"}", w.Body.String()) + assert.Equal(t, `{"message":"user 'test': user already exist"}`, w.Body.String()) } diff --git a/pkg/appsec/appsec_rule/appsec_rule.go b/pkg/appsec/appsec_rule/appsec_rule.go index c011e58fb..4bc46ef50 100644 --- a/pkg/appsec/appsec_rule/appsec_rule.go +++ b/pkg/appsec/appsec_rule/appsec_rule.go @@ -28,6 +28,7 @@ rules: type match struct { Type string `yaml:"type"` Value string `yaml:"value"` + Not bool `yaml:"not,omitempty"` } type CustomRule struct { @@ -40,7 +41,8 @@ type CustomRule struct { Transform []string `yaml:"transform"` //t:lowercase, t:uppercase, etc And []CustomRule `yaml:"and,omitempty"` Or []CustomRule `yaml:"or,omitempty"` - BodyType string `yaml:"body_type,omitempty"` + + BodyType string `yaml:"body_type,omitempty"` } func (v *CustomRule) Convert(ruleType string, appsecRuleName string) (string, []uint32, error) { diff --git a/pkg/appsec/appsec_rule/modsec_rule_test.go b/pkg/appsec/appsec_rule/modsec_rule_test.go index d919dce25..80411411d 100644 --- a/pkg/appsec/appsec_rule/modsec_rule_test.go +++ b/pkg/appsec/appsec_rule/modsec_rule_test.go @@ -8,6 +8,16 @@ func TestVPatchRuleString(t *testing.T) { rule CustomRule expected string }{ + { + name: "Collection count", + rule: CustomRule{ + Zones: []string{"ARGS"}, + Variables: []string{"foo"}, + Match: match{Type: "eq", Value: "1"}, + Transform: []string{"count"}, + }, + expected: `SecRule &ARGS_GET:foo "@eq 1" "id:853070236,phase:2,deny,log,msg:'Collection count',tag:'crowdsec-Collection count'"`, + }, { name: "Base Rule", rule: CustomRule{ @@ -18,6 +28,32 @@ func TestVPatchRuleString(t *testing.T) { }, expected: `SecRule ARGS_GET:foo "@rx [^a-zA-Z]" "id:2203944045,phase:2,deny,log,msg:'Base Rule',tag:'crowdsec-Base Rule',t:lowercase"`, }, + { + name: "One zone, multi var", + rule: CustomRule{ + Zones: []string{"ARGS"}, + Variables: []string{"foo", "bar"}, + Match: match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + expected: `SecRule ARGS_GET:foo|ARGS_GET:bar "@rx [^a-zA-Z]" "id:385719930,phase:2,deny,log,msg:'One zone, multi var',tag:'crowdsec-One zone, multi var',t:lowercase"`, + }, + { + name: "Base Rule #2", + rule: CustomRule{ + Zones: []string{"METHOD"}, + Match: match{Type: "startsWith", Value: "toto"}, + }, + expected: `SecRule REQUEST_METHOD "@beginsWith toto" "id:2759779019,phase:2,deny,log,msg:'Base Rule #2',tag:'crowdsec-Base Rule #2'"`, + }, + { + name: "Base Negative Rule", + rule: CustomRule{ + Zones: []string{"METHOD"}, + Match: match{Type: "startsWith", Value: "toto", Not: true}, + }, + expected: `SecRule REQUEST_METHOD "!@beginsWith toto" "id:3966251995,phase:2,deny,log,msg:'Base Negative Rule',tag:'crowdsec-Base Negative Rule'"`, + }, { name: "Multiple Zones", rule: CustomRule{ @@ -28,6 +64,25 @@ func TestVPatchRuleString(t *testing.T) { }, expected: `SecRule ARGS_GET:foo|ARGS_POST:foo "@rx [^a-zA-Z]" "id:3387135861,phase:2,deny,log,msg:'Multiple Zones',tag:'crowdsec-Multiple Zones',t:lowercase"`, }, + { + name: "Multiple Zones Multi Var", + rule: CustomRule{ + Zones: []string{"ARGS", "BODY_ARGS"}, + Variables: []string{"foo", "bar"}, + Match: match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + expected: `SecRule ARGS_GET:foo|ARGS_GET:bar|ARGS_POST:foo|ARGS_POST:bar "@rx [^a-zA-Z]" "id:1119773585,phase:2,deny,log,msg:'Multiple Zones Multi Var',tag:'crowdsec-Multiple Zones Multi Var',t:lowercase"`, + }, + { + name: "Multiple Zones No Vars", + rule: CustomRule{ + Zones: []string{"ARGS", "BODY_ARGS"}, + Match: match{Type: "regex", Value: "[^a-zA-Z]"}, + Transform: []string{"lowercase"}, + }, + expected: `SecRule ARGS_GET|ARGS_POST "@rx [^a-zA-Z]" "id:2020110336,phase:2,deny,log,msg:'Multiple Zones No Vars',tag:'crowdsec-Multiple Zones No Vars',t:lowercase"`, + }, { name: "Basic AND", rule: CustomRule{ diff --git a/pkg/appsec/appsec_rule/modsecurity.go b/pkg/appsec/appsec_rule/modsecurity.go index 7b98206c6..0b117cd77 100644 --- a/pkg/appsec/appsec_rule/modsecurity.go +++ b/pkg/appsec/appsec_rule/modsecurity.go @@ -44,6 +44,7 @@ var matchMap map[string]string = map[string]string{ "lt": "@lt", "gte": "@ge", "lte": "@le", + "eq": "@eq", } var bodyTypeMatch map[string]string = map[string]string{ @@ -121,7 +122,20 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an return ret, nil } + zone_prefix := "" + variable_prefix := "" + if rule.Transform != nil { + for tidx, transform := range rule.Transform { + if transform == "count" { + zone_prefix = "&" + rule.Transform[tidx] = "" + } + } + } for idx, zone := range rule.Zones { + if idx > 0 { + r.WriteByte('|') + } mappedZone, ok := zonesMap[zone] if !ok { return nil, fmt.Errorf("unknown zone '%s'", zone) @@ -130,10 +144,10 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an r.WriteString(mappedZone) } else { for j, variable := range rule.Variables { - if idx > 0 || j > 0 { + if j > 0 { r.WriteByte('|') } - r.WriteString(fmt.Sprintf("%s:%s", mappedZone, variable)) + r.WriteString(fmt.Sprintf("%s%s:%s%s", zone_prefix, mappedZone, variable_prefix, variable)) } } } @@ -141,7 +155,11 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an if rule.Match.Type != "" { if match, ok := matchMap[rule.Match.Type]; ok { - r.WriteString(fmt.Sprintf(`"%s %s"`, match, rule.Match.Value)) + prefix := "" + if rule.Match.Not { + prefix = "!" + } + r.WriteString(fmt.Sprintf(`"%s%s %s"`, prefix, match, rule.Match.Value)) } else { return nil, fmt.Errorf("unknown match type '%s'", rule.Match.Type) } @@ -152,6 +170,9 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an if rule.Transform != nil { for _, transform := range rule.Transform { + if transform == "" { + continue + } r.WriteByte(',') if mappedTransform, ok := transformMap[transform]; ok { r.WriteString(mappedTransform) diff --git a/pkg/appsec/request.go b/pkg/appsec/request.go index f9ff4c169..c82529acc 100644 --- a/pkg/appsec/request.go +++ b/pkg/appsec/request.go @@ -11,6 +11,7 @@ import ( "regexp" "github.com/google/uuid" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" ) @@ -267,7 +268,7 @@ func (r *ReqDumpFilter) ToJSON() error { } // Generate a ParsedRequest from a http.Request. ParsedRequest can be consumed by the App security Engine -func NewParsedRequestFromRequest(r *http.Request) (ParsedRequest, error) { +func NewParsedRequestFromRequest(r *http.Request, logger *logrus.Entry) (ParsedRequest, error) { var err error contentLength := r.ContentLength if contentLength < 0 { @@ -282,26 +283,23 @@ func NewParsedRequestFromRequest(r *http.Request) (ParsedRequest, error) { } } - // the real source of the request is set in 'x-client-ip' clientIP := r.Header.Get(IPHeaderName) if clientIP == "" { return ParsedRequest{}, fmt.Errorf("missing '%s' header", IPHeaderName) } - // the real target Host of the request is set in 'x-client-host' - clientHost := r.Header.Get(HostHeaderName) - if clientHost == "" { - return ParsedRequest{}, fmt.Errorf("missing '%s' header", HostHeaderName) - } - // the real URI of the request is set in 'x-client-uri' + clientURI := r.Header.Get(URIHeaderName) if clientURI == "" { return ParsedRequest{}, fmt.Errorf("missing '%s' header", URIHeaderName) } - // the real VERB of the request is set in 'x-client-uri' clientMethod := r.Header.Get(VerbHeaderName) if clientMethod == "" { return ParsedRequest{}, fmt.Errorf("missing '%s' header", VerbHeaderName) } + clientHost := r.Header.Get(HostHeaderName) + if clientHost == "" { //this might be empty + logger.Debugf("missing '%s' header", HostHeaderName) + } // delete those headers before coraza process the request delete(r.Header, IPHeaderName) diff --git a/pkg/csconfig/api.go b/pkg/csconfig/api.go index 07b8d154c..06b3d3828 100644 --- a/pkg/csconfig/api.go +++ b/pkg/csconfig/api.go @@ -61,36 +61,51 @@ func (a *CTICfg) Load() error { if a.Key == nil { *a.Enabled = false } + if a.Key != nil && *a.Key == "" { return fmt.Errorf("empty cti key") } + if a.Enabled == nil { a.Enabled = new(bool) *a.Enabled = true } + if a.CacheTimeout == nil { a.CacheTimeout = new(time.Duration) *a.CacheTimeout = 10 * time.Minute } + if a.CacheSize == nil { a.CacheSize = new(int) *a.CacheSize = 100 } + return nil } func (o *OnlineApiClientCfg) Load() error { o.Credentials = new(ApiCredentialsCfg) + fcontent, err := os.ReadFile(o.CredentialsFilePath) if err != nil { - return fmt.Errorf("failed to read api server credentials configuration file '%s': %w", o.CredentialsFilePath, err) + return err } + err = yaml.UnmarshalStrict(fcontent, o.Credentials) if err != nil { return fmt.Errorf("failed unmarshaling api server credentials configuration file '%s': %w", o.CredentialsFilePath, err) } - if o.Credentials.Login == "" || o.Credentials.Password == "" || o.Credentials.URL == "" { - log.Warningf("can't load CAPI credentials from '%s' (missing field)", o.CredentialsFilePath) + + switch { + case o.Credentials.Login == "": + log.Warningf("can't load CAPI credentials from '%s' (missing login field)", o.CredentialsFilePath) + o.Credentials = nil + case o.Credentials.Password == "": + log.Warningf("can't load CAPI credentials from '%s' (missing password field)", o.CredentialsFilePath) + o.Credentials = nil + case o.Credentials.URL == "": + log.Warningf("can't load CAPI credentials from '%s' (missing url field)", o.CredentialsFilePath) o.Credentials = nil } @@ -99,14 +114,17 @@ func (o *OnlineApiClientCfg) Load() error { func (l *LocalApiClientCfg) Load() error { patcher := yamlpatch.NewPatcher(l.CredentialsFilePath, ".local") + fcontent, err := patcher.MergedPatchContent() if err != nil { return err } + err = yaml.UnmarshalStrict(fcontent, &l.Credentials) if err != nil { return fmt.Errorf("failed unmarshaling api client credential configuration file '%s': %w", l.CredentialsFilePath, err) } + if l.Credentials == nil || l.Credentials.URL == "" { return fmt.Errorf("no credentials or URL found in api client configuration '%s'", l.CredentialsFilePath) } @@ -137,9 +155,11 @@ func (l *LocalApiClientCfg) Load() error { if err != nil { log.Warningf("Error loading system CA certificates: %s", err) } + if caCertPool == nil { caCertPool = x509.NewCertPool() } + caCertPool.AppendCertsFromPEM(caCert) apiclient.CaCertPool = caCertPool } @@ -160,12 +180,15 @@ func (lapiCfg *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) { trustedIPs := make([]net.IPNet, 0) for _, ip := range lapiCfg.TrustedIPs { cidr := toValidCIDR(ip) + _, ipNet, err := net.ParseCIDR(cidr) if err != nil { return nil, err } + trustedIPs = append(trustedIPs, *ipNet) } + return trustedIPs, nil } @@ -177,6 +200,7 @@ func toValidCIDR(ip string) string { if strings.Contains(ip, ":") { return ip + "/128" } + return ip + "/32" } @@ -327,24 +351,30 @@ func parseCapiWhitelists(fd io.Reader) (*CapiWhitelist, error) { if errors.Is(err, io.EOF) { return nil, fmt.Errorf("empty file") } + return nil, err } + ret := &CapiWhitelist{ Ips: make([]net.IP, len(fromCfg.Ips)), Cidrs: make([]*net.IPNet, len(fromCfg.Cidrs)), } + for idx, v := range fromCfg.Ips { ip := net.ParseIP(v) if ip == nil { return nil, fmt.Errorf("invalid IP address: %s", v) } + ret.Ips[idx] = ip } + for idx, v := range fromCfg.Cidrs { _, tnet, err := net.ParseCIDR(v) if err != nil { return nil, err } + ret.Cidrs[idx] = tnet } @@ -356,10 +386,6 @@ func (s *LocalApiServerCfg) LoadCapiWhitelists() error { return nil } - if _, err := os.Stat(s.CapiWhitelistsPath); os.IsNotExist(err) { - return fmt.Errorf("capi whitelist file '%s' does not exist", s.CapiWhitelistsPath) - } - fd, err := os.Open(s.CapiWhitelistsPath) if err != nil { return fmt.Errorf("while opening capi whitelist file: %s", err) diff --git a/pkg/csconfig/api_test.go b/pkg/csconfig/api_test.go index b39d6eccf..e1e24e2be 100644 --- a/pkg/csconfig/api_test.go +++ b/pkg/csconfig/api_test.go @@ -116,7 +116,7 @@ func TestLoadOnlineApiClientCfg(t *testing.T) { CredentialsFilePath: "./testdata/nonexist_online-api-secrets.yaml", }, expected: &ApiCredentialsCfg{}, - expectedErr: "failed to read api server credentials", + expectedErr: "open ./testdata/nonexist_online-api-secrets.yaml: " + cstest.FileNotFoundMessage, }, } @@ -254,6 +254,7 @@ func TestLoadAPIServer(t *testing.T) { func mustParseCIDRNet(t *testing.T, s string) *net.IPNet { _, ipNet, err := net.ParseCIDR(s) require.NoError(t, err) + return ipNet } diff --git a/pkg/csplugin/broker_test.go b/pkg/csplugin/broker_test.go index f41eb8031..9adb35ad7 100644 --- a/pkg/csplugin/broker_test.go +++ b/pkg/csplugin/broker_test.go @@ -149,7 +149,7 @@ func (s *PluginSuite) TestBrokerNoThreshold() { t := s.T() pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -182,7 +182,7 @@ func (s *PluginSuite) TestBrokerNoThreshold() { err = json.Unmarshal(content, &alerts) log.Printf("content-> %s", content) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) } @@ -199,7 +199,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { s.writeconfig(cfg) pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -215,11 +215,11 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() { time.Sleep(1 * time.Second) // after 1 seconds, we should have data content, err := os.ReadFile("./out") - assert.NoError(t, err) + require.NoError(t, err) var alerts []models.Alert err = json.Unmarshal(content, &alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 3) } @@ -235,7 +235,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { s.writeconfig(cfg) pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -259,7 +259,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() { var alerts []models.Alert err = json.Unmarshal(content, &alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 4) } @@ -275,7 +275,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { s.writeconfig(cfg) pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -306,11 +306,11 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() { // two notifications, one with 4 alerts, one with 2 alerts err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 4) err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 2) err = decoder.Decode(&alerts) @@ -328,7 +328,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { s.writeconfig(cfg) pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -348,7 +348,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() { var alerts []models.Alert err = json.Unmarshal(content, &alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) } @@ -358,7 +358,7 @@ func (s *PluginSuite) TestBrokerRunSimple() { t := s.T() pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -382,11 +382,11 @@ func (s *PluginSuite) TestBrokerRunSimple() { // two notifications, one alert each err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) diff --git a/pkg/csplugin/broker_win_test.go b/pkg/csplugin/broker_win_test.go index 6466e0d54..97a3ad33d 100644 --- a/pkg/csplugin/broker_win_test.go +++ b/pkg/csplugin/broker_win_test.go @@ -70,7 +70,7 @@ func (s *PluginSuite) TestBrokerRun() { t := s.T() pb, err := s.InitBroker(nil) - assert.NoError(t, err) + require.NoError(t, err) tomb := tomb.Tomb{} go pb.Run(&tomb) @@ -94,11 +94,11 @@ func (s *PluginSuite) TestBrokerRun() { // two notifications, one alert each err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, alerts, 1) err = decoder.Decode(&alerts) diff --git a/pkg/csprofiles/csprofiles.go b/pkg/csprofiles/csprofiles.go index c71291cdb..95fbb356f 100644 --- a/pkg/csprofiles/csprofiles.go +++ b/pkg/csprofiles/csprofiles.go @@ -6,7 +6,6 @@ import ( "github.com/antonmedv/expr" "github.com/antonmedv/expr/vm" - "github.com/pkg/errors" log "github.com/sirupsen/logrus" "github.com/crowdsecurity/crowdsec/pkg/csconfig" @@ -22,19 +21,23 @@ type Runtime struct { Logger *log.Entry `json:"-" yaml:"-"` } -var defaultDuration = "4h" +const defaultDuration = "4h" func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) { var err error + profilesRuntime := make([]*Runtime, 0) for _, profile := range profilesCfg { var runtimeFilter, runtimeDurationExpr *vm.Program + runtime := &Runtime{} + xlog := log.New() if err := types.ConfigureLogger(xlog); err != nil { log.Fatalf("While creating profiles-specific logger : %s", err) } + xlog.SetLevel(log.InfoLevel) runtime.Logger = xlog.WithFields(log.Fields{ "type": "profile", @@ -43,17 +46,20 @@ func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) { runtime.RuntimeFilters = make([]*vm.Program, len(profile.Filters)) runtime.Cfg = profile - if runtime.Cfg.OnSuccess != "" && runtime.Cfg.OnSuccess != "continue" && runtime.Cfg.OnSuccess != "break" { - return []*Runtime{}, fmt.Errorf("invalid 'on_success' for '%s': %s", profile.Name, runtime.Cfg.OnSuccess) - } - if runtime.Cfg.OnFailure != "" && runtime.Cfg.OnFailure != "continue" && runtime.Cfg.OnFailure != "break" && runtime.Cfg.OnFailure != "apply" { - return []*Runtime{}, fmt.Errorf("invalid 'on_failure' for '%s' : %s", profile.Name, runtime.Cfg.OnFailure) - } - for fIdx, filter := range profile.Filters { + if runtime.Cfg.OnSuccess != "" && runtime.Cfg.OnSuccess != "continue" && runtime.Cfg.OnSuccess != "break" { + return nil, fmt.Errorf("invalid 'on_success' for '%s': %s", profile.Name, runtime.Cfg.OnSuccess) + } + + if runtime.Cfg.OnFailure != "" && runtime.Cfg.OnFailure != "continue" && runtime.Cfg.OnFailure != "break" && runtime.Cfg.OnFailure != "apply" { + return nil, fmt.Errorf("invalid 'on_failure' for '%s' : %s", profile.Name, runtime.Cfg.OnFailure) + } + + for fIdx, filter := range profile.Filters { if runtimeFilter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil { - return []*Runtime{}, errors.Wrapf(err, "error compiling filter of '%s'", profile.Name) + return nil, fmt.Errorf("error compiling filter of '%s': %w", profile.Name, err) } + runtime.RuntimeFilters[fIdx] = runtimeFilter if profile.Debug != nil && *profile.Debug { runtime.Logger.Logger.SetLevel(log.DebugLevel) @@ -62,8 +68,9 @@ func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) { if profile.DurationExpr != "" { if runtimeDurationExpr, err = expr.Compile(profile.DurationExpr, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil { - return []*Runtime{}, errors.Wrapf(err, "error compiling duration_expr of %s", profile.Name) + return nil, fmt.Errorf("error compiling duration_expr of %s: %w", profile.Name, err) } + runtime.RuntimeDurationExpr = runtimeDurationExpr } @@ -76,14 +83,16 @@ func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) { runtime.Logger.Warningf("No duration specified for %s, using default duration %s", profile.Name, defaultDuration) duration = defaultDuration } + if _, err := time.ParseDuration(duration); err != nil { - return []*Runtime{}, errors.Wrapf(err, "error parsing duration '%s' of %s", duration, profile.Name) + return nil, fmt.Errorf("error parsing duration '%s' of %s: %w", duration, profile.Name, err) } } } profilesRuntime = append(profilesRuntime, runtime) } + return profilesRuntime, nil } @@ -110,30 +119,29 @@ func (Profile *Runtime) GenerateDecisionFromProfile(Alert *models.Alert) ([]*mod *decision.Scope = *Alert.Source.Scope } /*some fields are populated from the reference object : duration, scope, type*/ + decision.Duration = new(string) + if refDecision.Duration != nil { + *decision.Duration = *refDecision.Duration + } + if Profile.Cfg.DurationExpr != "" && Profile.RuntimeDurationExpr != nil { profileDebug := false if Profile.Cfg.Debug != nil && *Profile.Cfg.Debug { profileDebug = true } + duration, err := exprhelpers.Run(Profile.RuntimeDurationExpr, map[string]interface{}{"Alert": Alert}, Profile.Logger, profileDebug) if err != nil { Profile.Logger.Warningf("Failed to run duration_expr : %v", err) - *decision.Duration = *refDecision.Duration } else { durationStr := fmt.Sprint(duration) if _, err := time.ParseDuration(durationStr); err != nil { Profile.Logger.Warningf("Failed to parse expr duration result '%s'", duration) - *decision.Duration = *refDecision.Duration } else { *decision.Duration = durationStr } } - } else { - if refDecision.Duration == nil { - *decision.Duration = defaultDuration - } - *decision.Duration = *refDecision.Duration } decision.Type = new(string) @@ -144,13 +152,16 @@ func (Profile *Runtime) GenerateDecisionFromProfile(Alert *models.Alert) ([]*mod *decision.Value = *Alert.Source.Value decision.Origin = new(string) *decision.Origin = types.CrowdSecOrigin + if refDecision.Origin != nil { *decision.Origin = fmt.Sprintf("%s/%s", *decision.Origin, *refDecision.Origin) } + decision.Scenario = new(string) *decision.Scenario = *Alert.Scenario decisions = append(decisions, &decision) } + return decisions, nil } @@ -159,16 +170,19 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision var decisions []*models.Decision matched := false + for eIdx, expression := range Profile.RuntimeFilters { debugProfile := false if Profile.Cfg.Debug != nil && *Profile.Cfg.Debug { debugProfile = true } + output, err := exprhelpers.Run(expression, map[string]interface{}{"Alert": Alert}, Profile.Logger, debugProfile) if err != nil { - Profile.Logger.Warningf("failed to run profile expr for %s : %v", Profile.Cfg.Name, err) - return nil, matched, errors.Wrapf(err, "while running expression %s", Profile.Cfg.Filters[eIdx]) + Profile.Logger.Warningf("failed to run profile expr for %s: %v", Profile.Cfg.Name, err) + return nil, matched, fmt.Errorf("while running expression %s: %w", Profile.Cfg.Filters[eIdx], err) } + switch out := output.(type) { case bool: if out { @@ -176,7 +190,7 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision /*the expression matched, create the associated decision*/ subdecisions, err := Profile.GenerateDecisionFromProfile(Alert) if err != nil { - return nil, matched, errors.Wrapf(err, "while generating decision from profile %s", Profile.Cfg.Name) + return nil, matched, fmt.Errorf("while generating decision from profile %s: %w", Profile.Cfg.Name, err) } decisions = append(decisions, subdecisions...) @@ -189,9 +203,7 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision default: return nil, matched, fmt.Errorf("unexpected type %t (%v) while running '%s'", output, output, Profile.Cfg.Filters[eIdx]) - } - } return decisions, matched, nil diff --git a/pkg/cticlient/client_test.go b/pkg/cticlient/client_test.go index a229bde55..79406a6c2 100644 --- a/pkg/cticlient/client_test.go +++ b/pkg/cticlient/client_test.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/ptr" ) @@ -36,25 +37,30 @@ func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { // wip func fireHandler(req *http.Request) *http.Response { var err error + apiKey := req.Header.Get("x-api-key") if apiKey != validApiKey { log.Warningf("invalid api key: %s", apiKey) + return &http.Response{ StatusCode: http.StatusForbidden, Body: nil, Header: make(http.Header), } } + //unmarshal data if fireResponses == nil { page1, err := os.ReadFile("tests/fire-page1.json") if err != nil { panic("can't read file") } + page2, err := os.ReadFile("tests/fire-page2.json") if err != nil { panic("can't read file") } + fireResponses = []string{string(page1), string(page2)} } //let's assume we have two valid pages. @@ -70,6 +76,7 @@ func fireHandler(req *http.Request) *http.Response { //how to react if you give a page number that is too big ? if page > len(fireResponses) { log.Warningf(" page too big %d vs %d", page, len(fireResponses)) + emptyResponse := `{ "_links": { "first": { @@ -82,8 +89,10 @@ func fireHandler(req *http.Request) *http.Response { "items": [] } ` + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(emptyResponse))} } + reader := io.NopCloser(strings.NewReader(fireResponses[page-1])) //we should care about limit too return &http.Response{ @@ -106,6 +115,7 @@ func smokeHandler(req *http.Request) *http.Response { } requestedIP := strings.Split(req.URL.Path, "/")[3] + response, ok := smokeResponses[requestedIP] if !ok { return &http.Response{ @@ -135,6 +145,7 @@ func rateLimitedHandler(req *http.Request) *http.Response { Header: make(http.Header), } } + return &http.Response{ StatusCode: http.StatusTooManyRequests, Body: nil, @@ -151,7 +162,9 @@ func searchHandler(req *http.Request) *http.Response { Header: make(http.Header), } } + url, _ := url.Parse(req.URL.String()) + ipsParam := url.Query().Get("ips") if ipsParam == "" { return &http.Response{ @@ -163,6 +176,7 @@ func searchHandler(req *http.Request) *http.Response { totalIps := 0 notFound := 0 + ips := strings.Split(ipsParam, ",") for _, ip := range ips { _, ok := smokeResponses[ip] @@ -172,12 +186,15 @@ func searchHandler(req *http.Request) *http.Response { notFound++ } } + response := fmt.Sprintf(`{"total": %d, "not_found": %d, "items": [`, totalIps, notFound) for _, ip := range ips { response += smokeResponses[ip] } + response += "]}" reader := io.NopCloser(strings.NewReader(response)) + return &http.Response{ StatusCode: http.StatusOK, Body: reader, @@ -190,7 +207,7 @@ func TestBadFireAuth(t *testing.T) { Transport: RoundTripFunc(fireHandler), })) _, err := ctiClient.Fire(FireParams{}) - assert.EqualError(t, err, ErrUnauthorized.Error()) + require.EqualError(t, err, ErrUnauthorized.Error()) } func TestFireOk(t *testing.T) { @@ -198,19 +215,19 @@ func TestFireOk(t *testing.T) { Transport: RoundTripFunc(fireHandler), })) data, err := cticlient.Fire(FireParams{}) - assert.Equal(t, err, nil) - assert.Equal(t, len(data.Items), 3) - assert.Equal(t, data.Items[0].Ip, "1.2.3.4") + require.NoError(t, err) + assert.Len(t, data.Items, 3) + assert.Equal(t, "1.2.3.4", data.Items[0].Ip) //page 1 is the default data, err = cticlient.Fire(FireParams{Page: ptr.Of(1)}) - assert.Equal(t, err, nil) - assert.Equal(t, len(data.Items), 3) - assert.Equal(t, data.Items[0].Ip, "1.2.3.4") + require.NoError(t, err) + assert.Len(t, data.Items, 3) + assert.Equal(t, "1.2.3.4", data.Items[0].Ip) //page 2 data, err = cticlient.Fire(FireParams{Page: ptr.Of(2)}) - assert.Equal(t, err, nil) - assert.Equal(t, len(data.Items), 3) - assert.Equal(t, data.Items[0].Ip, "4.2.3.4") + require.NoError(t, err) + assert.Len(t, data.Items, 3) + assert.Equal(t, "4.2.3.4", data.Items[0].Ip) } func TestFirePaginator(t *testing.T) { @@ -219,17 +236,16 @@ func TestFirePaginator(t *testing.T) { })) paginator := NewFirePaginator(cticlient, FireParams{}) items, err := paginator.Next() - assert.Equal(t, err, nil) - assert.Equal(t, len(items), 3) - assert.Equal(t, items[0].Ip, "1.2.3.4") + require.NoError(t, err) + assert.Len(t, items, 3) + assert.Equal(t, "1.2.3.4", items[0].Ip) items, err = paginator.Next() - assert.Equal(t, err, nil) - assert.Equal(t, len(items), 3) - assert.Equal(t, items[0].Ip, "4.2.3.4") + require.NoError(t, err) + assert.Len(t, items, 3) + assert.Equal(t, "4.2.3.4", items[0].Ip) items, err = paginator.Next() - assert.Equal(t, err, nil) - assert.Equal(t, len(items), 0) - + require.NoError(t, err) + assert.Empty(t, items) } func TestBadSmokeAuth(t *testing.T) { @@ -237,13 +253,14 @@ func TestBadSmokeAuth(t *testing.T) { Transport: RoundTripFunc(smokeHandler), })) _, err := ctiClient.GetIPInfo("1.1.1.1") - assert.EqualError(t, err, ErrUnauthorized.Error()) + require.EqualError(t, err, ErrUnauthorized.Error()) } func TestSmokeInfoValidIP(t *testing.T) { ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ Transport: RoundTripFunc(smokeHandler), })) + resp, err := ctiClient.GetIPInfo("1.1.1.1") if err != nil { t.Fatalf("failed to get ip info: %s", err) @@ -257,6 +274,7 @@ func TestSmokeUnknownIP(t *testing.T) { ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ Transport: RoundTripFunc(smokeHandler), })) + resp, err := ctiClient.GetIPInfo("42.42.42.42") if err != nil { t.Fatalf("failed to get ip info: %s", err) @@ -270,20 +288,22 @@ func TestRateLimit(t *testing.T) { Transport: RoundTripFunc(rateLimitedHandler), })) _, err := ctiClient.GetIPInfo("1.1.1.1") - assert.EqualError(t, err, ErrLimit.Error()) + require.EqualError(t, err, ErrLimit.Error()) } func TestSearchIPs(t *testing.T) { ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{ Transport: RoundTripFunc(searchHandler), })) + resp, err := ctiClient.SearchIPs([]string{"1.1.1.1", "42.42.42.42"}) if err != nil { t.Fatalf("failed to search ips: %s", err) } + assert.Equal(t, 1, resp.Total) assert.Equal(t, 1, resp.NotFound) - assert.Equal(t, 1, len(resp.Items)) + assert.Len(t, resp.Items, 1) assert.Equal(t, "1.1.1.1", resp.Items[0].Ip) } diff --git a/pkg/cticlient/types_test.go b/pkg/cticlient/types_test.go index c20acc95a..a7308af35 100644 --- a/pkg/cticlient/types_test.go +++ b/pkg/cticlient/types_test.go @@ -88,27 +88,28 @@ func getSampleSmokeItem() SmokeItem { }, }, } + return emptyItem } func TestBasicSmokeItem(t *testing.T) { item := getSampleSmokeItem() - assert.Equal(t, item.GetAttackDetails(), []string{"ssh:bruteforce"}) - assert.Equal(t, item.GetBehaviors(), []string{"ssh:bruteforce"}) - assert.Equal(t, item.GetMaliciousnessScore(), float32(0.1)) - assert.Equal(t, item.IsPartOfCommunityBlocklist(), false) - assert.Equal(t, item.GetBackgroundNoiseScore(), int(3)) - assert.Equal(t, item.GetFalsePositives(), []string{}) - assert.Equal(t, item.IsFalsePositive(), false) + assert.Equal(t, []string{"ssh:bruteforce"}, item.GetAttackDetails()) + assert.Equal(t, []string{"ssh:bruteforce"}, item.GetBehaviors()) + assert.InDelta(t, 0.1, item.GetMaliciousnessScore(), 0.000001) + assert.False(t, item.IsPartOfCommunityBlocklist()) + assert.Equal(t, 3, item.GetBackgroundNoiseScore()) + assert.Equal(t, []string{}, item.GetFalsePositives()) + assert.False(t, item.IsFalsePositive()) } func TestEmptySmokeItem(t *testing.T) { item := SmokeItem{} - assert.Equal(t, item.GetAttackDetails(), []string{}) - assert.Equal(t, item.GetBehaviors(), []string{}) - assert.Equal(t, item.GetMaliciousnessScore(), float32(0.0)) - assert.Equal(t, item.IsPartOfCommunityBlocklist(), false) - assert.Equal(t, item.GetBackgroundNoiseScore(), int(0)) - assert.Equal(t, item.GetFalsePositives(), []string{}) - assert.Equal(t, item.IsFalsePositive(), false) + assert.Equal(t, []string{}, item.GetAttackDetails()) + assert.Equal(t, []string{}, item.GetBehaviors()) + assert.InDelta(t, 0.0, item.GetMaliciousnessScore(), 0) + assert.False(t, item.IsPartOfCommunityBlocklist()) + assert.Equal(t, 0, item.GetBackgroundNoiseScore()) + assert.Equal(t, []string{}, item.GetFalsePositives()) + assert.False(t, item.IsFalsePositive()) } diff --git a/pkg/cwhub/dataset.go b/pkg/cwhub/dataset.go index 06b9be1c0..c900752b8 100644 --- a/pkg/cwhub/dataset.go +++ b/pkg/cwhub/dataset.go @@ -56,6 +56,7 @@ func downloadFile(url string, destPath string) error { // if the remote has no modification date, but local file has been modified > a week ago, update. func needsUpdate(destPath string, url string, logger *logrus.Logger) bool { fileInfo, err := os.Stat(destPath) + switch { case os.IsNotExist(err): return true @@ -89,6 +90,7 @@ func needsUpdate(destPath string, url string, logger *logrus.Logger) bool { if localIsOld { logger.Infof("no last modified date for %s, but local file is older than %s", url, shelfLife) } + return localIsOld } @@ -129,6 +131,7 @@ func downloadDataSet(dataFolder string, force bool, reader io.Reader, logger *lo if force || needsUpdate(destPath, dataS.SourceURL, logger) { logger.Debugf("downloading %s in %s", dataS.SourceURL, destPath) + if err := downloadFile(dataS.SourceURL, destPath); err != nil { return fmt.Errorf("while getting data: %w", err) } diff --git a/pkg/cwhub/hub.go b/pkg/cwhub/hub.go index cb7a48e7a..21a19bc45 100644 --- a/pkg/cwhub/hub.go +++ b/pkg/cwhub/hub.go @@ -7,10 +7,10 @@ import ( "io" "os" "path" + "slices" "strings" "github.com/sirupsen/logrus" - "slices" "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) @@ -97,6 +97,10 @@ func (h *Hub) parseIndex() error { item.FileName = path.Base(item.RemotePath) item.logMissingSubItems() + + if item.latestHash() == "" { + h.logger.Errorf("invalid hub item %s: latest version missing from index", item.FQName()) + } } } diff --git a/pkg/cwhub/hub_test.go b/pkg/cwhub/hub_test.go index 17cffb57b..86569cde3 100644 --- a/pkg/cwhub/hub_test.go +++ b/pkg/cwhub/hub_test.go @@ -12,7 +12,6 @@ import ( func TestInitHubUpdate(t *testing.T) { hub := envSetup(t) - remote := &RemoteHubCfg{ URLTemplate: mockURLTemplate, Branch: "master", diff --git a/pkg/cwhub/item.go b/pkg/cwhub/item.go index e01b220fa..6c7da06c3 100644 --- a/pkg/cwhub/item.go +++ b/pkg/cwhub/item.go @@ -4,10 +4,10 @@ import ( "encoding/json" "fmt" "path/filepath" + "slices" "github.com/Masterminds/semver/v3" "github.com/enescakir/emoji" - "slices" ) const ( @@ -440,3 +440,15 @@ func (i *Item) addTaint(sub *Item) { ancestor.addTaint(sub) } } + +// latestHash() returns the hash of the latest version of the item. +// if it's missing, the index file has been manually modified or got corrupted. +func (i *Item) latestHash() string { + for k, v := range i.Versions { + if k == i.Version { + return v.Digest + } + } + + return "" +} diff --git a/pkg/cwhub/iteminstall.go b/pkg/cwhub/iteminstall.go index 3a5a0f664..ceae36491 100644 --- a/pkg/cwhub/iteminstall.go +++ b/pkg/cwhub/iteminstall.go @@ -50,7 +50,7 @@ func (i *Item) Install(force bool, downloadOnly bool) error { filePath, err := i.downloadLatest(force, true) if err != nil { - return fmt.Errorf("while downloading %s: %w", i.Name, err) + return err } if downloadOnly { diff --git a/pkg/cwhub/itemremove.go b/pkg/cwhub/itemremove.go index ccbbe3d38..eca0c8562 100644 --- a/pkg/cwhub/itemremove.go +++ b/pkg/cwhub/itemremove.go @@ -3,7 +3,6 @@ package cwhub import ( "fmt" "os" - "slices" ) diff --git a/pkg/cwhub/itemupgrade.go b/pkg/cwhub/itemupgrade.go index 073bd8797..ac3b94f98 100644 --- a/pkg/cwhub/itemupgrade.go +++ b/pkg/cwhub/itemupgrade.go @@ -6,6 +6,7 @@ import ( "bytes" "crypto/sha256" "encoding/hex" + "errors" "fmt" "io" "net/http" @@ -82,14 +83,14 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (string, error) { i.hub.logger.Tracef("collection, recurse") if _, err := sub.downloadLatest(overwrite, updateOnly); err != nil { - return "", fmt.Errorf("while downloading %s: %w", sub.Name, err) + return "", err } } downloaded := sub.State.Downloaded if _, err := sub.download(overwrite); err != nil { - return "", fmt.Errorf("while downloading %s: %w", sub.Name, err) + return "", err } // We need to enable an item when it has been added to a collection since latest release of the collection. @@ -108,7 +109,7 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (string, error) { ret, err := i.download(overwrite) if err != nil { - return "", fmt.Errorf("failed to download item: %w", err) + return "", err } return ret, nil @@ -116,6 +117,10 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (string, error) { // FetchLatest downloads the latest item from the hub, verifies the hash and returns the content and the used url. func (i *Item) FetchLatest() ([]byte, string, error) { + if i.latestHash() == "" { + return nil, "", errors.New("latest hash missing from index") + } + url, err := i.hub.remote.urlTo(i.RemotePath) if err != nil { return nil, "", fmt.Errorf("failed to build request: %w", err) @@ -146,7 +151,7 @@ func (i *Item) FetchLatest() ([]byte, string, error) { i.hub.logger.Errorf("Downloaded version doesn't match index, please 'hub update'") i.hub.logger.Debugf("got %s, expected %s", meow, i.Versions[i.Version].Digest) - return nil, "", fmt.Errorf("invalid download hash for %s", i.Name) + return nil, "", fmt.Errorf("invalid download hash") } return body, url, nil @@ -180,7 +185,12 @@ func (i *Item) download(overwrite bool) (string, error) { body, url, err := i.FetchLatest() if err != nil { - return "", fmt.Errorf("while downloading %s: %w", url, err) + what := i.Name + if url != "" { + what += " from " + url + } + + return "", fmt.Errorf("while downloading %s: %w", what, err) } // all good, install diff --git a/pkg/cwhub/itemupgrade_test.go b/pkg/cwhub/itemupgrade_test.go index 3b1fbfadd..1bd62ad63 100644 --- a/pkg/cwhub/itemupgrade_test.go +++ b/pkg/cwhub/itemupgrade_test.go @@ -12,7 +12,6 @@ import ( // We expect the new scenario to be installed. func TestUpgradeItemNewScenarioInCollection(t *testing.T) { hub := envSetup(t) - item := hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection") // fresh install of collection @@ -65,7 +64,6 @@ func TestUpgradeItemNewScenarioInCollection(t *testing.T) { // Upgrade should install should not enable/download the disabled scenario. func TestUpgradeItemInDisabledScenarioShouldNotBeInstalled(t *testing.T) { hub := envSetup(t) - item := hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection") // fresh install of collection @@ -127,7 +125,6 @@ func getHubOrFail(t *testing.T, local *csconfig.LocalHubCfg, remote *RemoteHubCf // Upgrade should install and enable the newly added scenario. func TestUpgradeItemNewScenarioIsInstalledWhenReferencedScenarioIsDisabled(t *testing.T) { hub := envSetup(t) - item := hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection") // fresh install of collection diff --git a/pkg/cwhub/sync.go b/pkg/cwhub/sync.go index 98b7fa167..8ce91dc21 100644 --- a/pkg/cwhub/sync.go +++ b/pkg/cwhub/sync.go @@ -7,13 +7,13 @@ import ( "io" "os" "path/filepath" + "slices" "sort" "strings" "github.com/Masterminds/semver/v3" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" - "slices" ) func isYAMLFileName(path string) bool { diff --git a/pkg/dumps/bucket_dump.go b/pkg/dumps/bucket_dump.go new file mode 100644 index 000000000..5f5ce1c40 --- /dev/null +++ b/pkg/dumps/bucket_dump.go @@ -0,0 +1,32 @@ +package dumps + +import ( + "io" + "os" + + "github.com/crowdsecurity/crowdsec/pkg/types" + "gopkg.in/yaml.v2" +) + +type BucketPourInfo map[string][]types.Event + +func LoadBucketPourDump(filepath string) (*BucketPourInfo, error) { + dumpData, err := os.Open(filepath) + if err != nil { + return nil, err + } + defer dumpData.Close() + + results, err := io.ReadAll(dumpData) + if err != nil { + return nil, err + } + + var bucketDump BucketPourInfo + + if err := yaml.Unmarshal(results, &bucketDump); err != nil { + return nil, err + } + + return &bucketDump, nil +} diff --git a/pkg/dumps/parser_dump.go b/pkg/dumps/parser_dump.go new file mode 100644 index 000000000..566b87a08 --- /dev/null +++ b/pkg/dumps/parser_dump.go @@ -0,0 +1,319 @@ +package dumps + +import ( + "fmt" + "io" + "os" + "sort" + "strings" + "time" + + "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/crowdsecurity/go-cs-lib/maptools" + "github.com/enescakir/emoji" + "github.com/fatih/color" + diff "github.com/r3labs/diff/v2" + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v2" +) + +type ParserResult struct { + Idx int + Evt types.Event + Success bool +} + +type ParserResults map[string]map[string][]ParserResult + +type DumpOpts struct { + Details bool + SkipOk bool + ShowNotOkParsers bool +} + +func LoadParserDump(filepath string) (*ParserResults, error) { + dumpData, err := os.Open(filepath) + if err != nil { + return nil, err + } + defer dumpData.Close() + + results, err := io.ReadAll(dumpData) + if err != nil { + return nil, err + } + + pdump := ParserResults{} + + if err := yaml.Unmarshal(results, &pdump); err != nil { + return nil, err + } + + /* we know that some variables should always be set, + let's check if they're present in last parser output of last stage */ + + stages := maptools.SortedKeys(pdump) + + var lastStage string + + //Loop over stages to find last successful one with at least one parser + for i := len(stages) - 2; i >= 0; i-- { + if len(pdump[stages[i]]) != 0 { + lastStage = stages[i] + break + } + } + + parsers := make([]string, 0, len(pdump[lastStage])) + + for k := range pdump[lastStage] { + parsers = append(parsers, k) + } + + sort.Strings(parsers) + + if len(parsers) == 0 { + return nil, fmt.Errorf("no parser found. Please install the appropriate parser and retry") + } + + lastParser := parsers[len(parsers)-1] + + for idx, result := range pdump[lastStage][lastParser] { + if result.Evt.StrTime == "" { + log.Warningf("Line %d/%d is missing evt.StrTime. It is most likely a mistake as it will prevent your logs to be processed in time-machine/forensic mode.", idx, len(pdump[lastStage][lastParser])) + } else { + log.Debugf("Line %d/%d has evt.StrTime set to '%s'", idx, len(pdump[lastStage][lastParser]), result.Evt.StrTime) + } + } + + return &pdump, nil +} + +func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpOpts) { + //note : we can use line -> time as the unique identifier (of acquisition) + state := make(map[time.Time]map[string]map[string]ParserResult) + assoc := make(map[time.Time]string, 0) + parser_order := make(map[string][]string) + + for stage, parsers := range parserResults { + //let's process parsers in the order according to idx + parser_order[stage] = make([]string, len(parsers)) + for pname, parser := range parsers { + if len(parser) > 0 { + parser_order[stage][parser[0].Idx-1] = pname + } + } + + for _, parser := range parser_order[stage] { + results := parsers[parser] + for _, parserRes := range results { + evt := parserRes.Evt + if _, ok := state[evt.Line.Time]; !ok { + state[evt.Line.Time] = make(map[string]map[string]ParserResult) + assoc[evt.Line.Time] = evt.Line.Raw + } + + if _, ok := state[evt.Line.Time][stage]; !ok { + state[evt.Line.Time][stage] = make(map[string]ParserResult) + } + + state[evt.Line.Time][stage][parser] = ParserResult{Evt: evt, Success: parserRes.Success} + } + } + } + + for bname, evtlist := range bucketPour { + for _, evt := range evtlist { + if evt.Line.Raw == "" { + continue + } + + //it might be bucket overflow being reprocessed, skip this + if _, ok := state[evt.Line.Time]; !ok { + state[evt.Line.Time] = make(map[string]map[string]ParserResult) + assoc[evt.Line.Time] = evt.Line.Raw + } + + //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase + //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered + if _, ok := state[evt.Line.Time]["buckets"]; !ok { + state[evt.Line.Time]["buckets"] = make(map[string]ParserResult) + } + + state[evt.Line.Time]["buckets"][bname] = ParserResult{Success: true} + } + } + + yellow := color.New(color.FgYellow).SprintFunc() + red := color.New(color.FgRed).SprintFunc() + green := color.New(color.FgGreen).SprintFunc() + whitelistReason := "" + //get each line + for tstamp, rawstr := range assoc { + if opts.SkipOk { + if _, ok := state[tstamp]["buckets"]["OK"]; ok { + continue + } + } + + fmt.Printf("line: %s\n", rawstr) + + skeys := make([]string, 0, len(state[tstamp])) + + for k := range state[tstamp] { + //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase + //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered + if k == "buckets" { + continue + } + + skeys = append(skeys, k) + } + + sort.Strings(skeys) + + // iterate stage + var prevItem types.Event + + for _, stage := range skeys { + parsers := state[tstamp][stage] + + sep := "├" + presep := "|" + + fmt.Printf("\t%s %s\n", sep, stage) + + for idx, parser := range parser_order[stage] { + res := parsers[parser].Success + sep := "├" + + if idx == len(parser_order[stage])-1 { + sep = "└" + } + + created := 0 + updated := 0 + deleted := 0 + whitelisted := false + changeStr := "" + detailsDisplay := "" + + if res { + changelog, _ := diff.Diff(prevItem, parsers[parser].Evt) + for _, change := range changelog { + switch change.Type { + case "create": + created++ + + detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), green(change.To)) + case "update": + detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s -> %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), change.From, yellow(change.To)) + + if change.Path[0] == "Whitelisted" && change.To == true { + whitelisted = true + + if whitelistReason == "" { + whitelistReason = parsers[parser].Evt.WhitelistReason + } + } + updated++ + case "delete": + deleted++ + + detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s\n", presep, sep, change.Type, red(strings.Join(change.Path, "."))) + } + } + + prevItem = parsers[parser].Evt + } + + if created > 0 { + changeStr += green(fmt.Sprintf("+%d", created)) + } + + if updated > 0 { + if len(changeStr) > 0 { + changeStr += " " + } + + changeStr += yellow(fmt.Sprintf("~%d", updated)) + } + + if deleted > 0 { + if len(changeStr) > 0 { + changeStr += " " + } + + changeStr += red(fmt.Sprintf("-%d", deleted)) + } + + if whitelisted { + if len(changeStr) > 0 { + changeStr += " " + } + + changeStr += red("[whitelisted]") + } + + if changeStr == "" { + changeStr = yellow("unchanged") + } + + if res { + fmt.Printf("\t%s\t%s %s %s (%s)\n", presep, sep, emoji.GreenCircle, parser, changeStr) + + if opts.Details { + fmt.Print(detailsDisplay) + } + } else if opts.ShowNotOkParsers { + fmt.Printf("\t%s\t%s %s %s\n", presep, sep, emoji.RedCircle, parser) + } + } + } + + sep := "└" + + if len(state[tstamp]["buckets"]) > 0 { + sep = "├" + } + + //did the event enter the bucket pour phase ? + if _, ok := state[tstamp]["buckets"]["OK"]; ok { + fmt.Printf("\t%s-------- parser success %s\n", sep, emoji.GreenCircle) + } else if whitelistReason != "" { + fmt.Printf("\t%s-------- parser success, ignored by whitelist (%s) %s\n", sep, whitelistReason, emoji.GreenCircle) + } else { + fmt.Printf("\t%s-------- parser failure %s\n", sep, emoji.RedCircle) + } + + //now print bucket info + if len(state[tstamp]["buckets"]) > 0 { + fmt.Printf("\t├ Scenarios\n") + } + + bnames := make([]string, 0, len(state[tstamp]["buckets"])) + + for k := range state[tstamp]["buckets"] { + //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase + //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered + if k == "OK" { + continue + } + + bnames = append(bnames, k) + } + + sort.Strings(bnames) + + for idx, bname := range bnames { + sep := "├" + if idx == len(bnames)-1 { + sep = "└" + } + + fmt.Printf("\t\t%s %s %s\n", sep, emoji.GreenCircle, bname) + } + + fmt.Println() + } +} diff --git a/pkg/exprhelpers/crowdsec_cti_test.go b/pkg/exprhelpers/crowdsec_cti_test.go index 80ccadba4..fc3a236c5 100644 --- a/pkg/exprhelpers/crowdsec_cti_test.go +++ b/pkg/exprhelpers/crowdsec_cti_test.go @@ -11,6 +11,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/crowdsecurity/go-cs-lib/ptr" @@ -78,6 +79,7 @@ func smokeHandler(req *http.Request) *http.Response { } requestedIP := strings.Split(req.URL.Path, "/")[3] + sample, ok := sampledata[requestedIP] if !ok { return &http.Response{ @@ -109,9 +111,11 @@ func smokeHandler(req *http.Request) *http.Response { func TestNillClient(t *testing.T) { defer ShutdownCrowdsecCTI() + if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); !errors.Is(err, cticlient.ErrDisabled) { t.Fatalf("failed to init CTI : %s", err) } + item, err := CrowdsecCTI("1.2.3.4") assert.Equal(t, err, cticlient.ErrDisabled) assert.Equal(t, item, &cticlient.SmokeItem{}) @@ -119,6 +123,7 @@ func TestNillClient(t *testing.T) { func TestInvalidAuth(t *testing.T) { defer ShutdownCrowdsecCTI() + if err := InitCrowdsecCTI(ptr.Of("asdasd"), nil, nil, nil); err != nil { t.Fatalf("failed to init CTI : %s", err) } @@ -129,7 +134,7 @@ func TestInvalidAuth(t *testing.T) { item, err := CrowdsecCTI("1.2.3.4") assert.Equal(t, item, &cticlient.SmokeItem{}) - assert.Equal(t, CTIApiEnabled, false) + assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrUnauthorized) //CTI is now disabled, all requests should return empty @@ -139,14 +144,15 @@ func TestInvalidAuth(t *testing.T) { item, err = CrowdsecCTI("1.2.3.4") assert.Equal(t, item, &cticlient.SmokeItem{}) - assert.Equal(t, CTIApiEnabled, false) + assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrDisabled) } func TestNoKey(t *testing.T) { defer ShutdownCrowdsecCTI() + err := InitCrowdsecCTI(nil, nil, nil, nil) - assert.ErrorIs(t, err, cticlient.ErrDisabled) + require.ErrorIs(t, err, cticlient.ErrDisabled) //Replace the client created by InitCrowdsecCTI with one that uses a custom transport ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey("asdasd"), cticlient.WithHTTPClient(&http.Client{ Transport: RoundTripFunc(smokeHandler), @@ -154,12 +160,13 @@ func TestNoKey(t *testing.T) { item, err := CrowdsecCTI("1.2.3.4") assert.Equal(t, item, &cticlient.SmokeItem{}) - assert.Equal(t, CTIApiEnabled, false) + assert.False(t, CTIApiEnabled) assert.Equal(t, err, cticlient.ErrDisabled) } func TestCache(t *testing.T) { defer ShutdownCrowdsecCTI() + cacheDuration := 1 * time.Second if err := InitCrowdsecCTI(ptr.Of(validApiKey), &cacheDuration, nil, nil); err != nil { t.Fatalf("failed to init CTI : %s", err) @@ -172,28 +179,27 @@ func TestCache(t *testing.T) { item, err := CrowdsecCTI("1.2.3.4") ctiResp := item.(*cticlient.SmokeItem) assert.Equal(t, "1.2.3.4", ctiResp.Ip) - assert.Equal(t, CTIApiEnabled, true) - assert.Equal(t, CTICache.Len(true), 1) - assert.Equal(t, err, nil) + assert.True(t, CTIApiEnabled) + assert.Equal(t, 1, CTICache.Len(true)) + require.NoError(t, err) item, err = CrowdsecCTI("1.2.3.4") ctiResp = item.(*cticlient.SmokeItem) assert.Equal(t, "1.2.3.4", ctiResp.Ip) - assert.Equal(t, CTIApiEnabled, true) - assert.Equal(t, CTICache.Len(true), 1) - assert.Equal(t, err, nil) + assert.True(t, CTIApiEnabled) + assert.Equal(t, 1, CTICache.Len(true)) + require.NoError(t, err) time.Sleep(2 * time.Second) - assert.Equal(t, CTICache.Len(true), 0) + assert.Equal(t, 0, CTICache.Len(true)) item, err = CrowdsecCTI("1.2.3.4") ctiResp = item.(*cticlient.SmokeItem) assert.Equal(t, "1.2.3.4", ctiResp.Ip) - assert.Equal(t, CTIApiEnabled, true) - assert.Equal(t, CTICache.Len(true), 1) - assert.Equal(t, err, nil) - + assert.True(t, CTIApiEnabled) + assert.Equal(t, 1, CTICache.Len(true)) + require.NoError(t, err) } diff --git a/pkg/exprhelpers/exprlib_test.go b/pkg/exprhelpers/exprlib_test.go index 4a6d5b74d..6b9cd15c7 100644 --- a/pkg/exprhelpers/exprlib_test.go +++ b/pkg/exprhelpers/exprlib_test.go @@ -28,17 +28,18 @@ var ( func getDBClient(t *testing.T) *database.Client { t.Helper() + dbPath, err := os.CreateTemp("", "*sqlite") require.NoError(t, err) - testDbClient, err := database.NewClient(&csconfig.DatabaseCfg{ + testDBClient, err := database.NewClient(&csconfig.DatabaseCfg{ Type: "sqlite", DbName: "crowdsec", DbPath: dbPath.Name(), }) require.NoError(t, err) - return testDbClient + return testDBClient } func TestVisitor(t *testing.T) { @@ -109,17 +110,18 @@ func TestVisitor(t *testing.T) { if err != nil && test.err == nil { log.Fatalf("run : %s", err) } + if isOk := assert.Equal(t, test.result, result); !isOk { t.Fatalf("test '%s' : NOK", test.filter) } } - } } func TestMatch(t *testing.T) { err := Init(nil) require.NoError(t, err) + tests := []struct { glob string val string @@ -149,12 +151,15 @@ func TestMatch(t *testing.T) { "pattern": test.glob, "name": test.val, } + vm, err := expr.Compile(test.expr, GetExprOptions(env)...) if err != nil { t.Fatalf("pattern:%s val:%s NOK %s", test.glob, test.val, err) } + ret, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) + if isOk := assert.Equal(t, test.ret, ret); !isOk { t.Fatalf("pattern:%s val:%s NOK %t != %t", test.glob, test.val, ret, test.ret) } @@ -194,10 +199,10 @@ func TestDistanceHelper(t *testing.T) { } ret, err := expr.Run(vm, env) if test.valid { - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.dist, ret) } else { - assert.NotNil(t, err) + require.Error(t, err) } }) } @@ -283,10 +288,12 @@ func TestRegexpInFile(t *testing.T) { if err != nil { log.Fatal(err) } + result, err := expr.Run(compiledFilter, map[string]interface{}{}) if err != nil { log.Fatal(err) } + if isOk := assert.Equal(t, test.result, result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } @@ -335,28 +342,34 @@ func TestFileInit(t *testing.T) { if err != nil { log.Fatal(err) } - if test.types == "string" { + + switch test.types { + case "string": if _, ok := dataFile[test.filename]; !ok { t.Fatalf("test '%s' : NOK", test.name) } - if isOk := assert.Equal(t, test.result, len(dataFile[test.filename])); !isOk { + + if isOk := assert.Len(t, dataFile[test.filename], test.result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } - } else if test.types == "regex" { + case "regex": if _, ok := dataFileRegex[test.filename]; !ok { t.Fatalf("test '%s' : NOK", test.name) } - if isOk := assert.Equal(t, test.result, len(dataFileRegex[test.filename])); !isOk { + + if isOk := assert.Len(t, dataFileRegex[test.filename], test.result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } - } else { + default: if _, ok := dataFileRegex[test.filename]; ok { t.Fatalf("test '%s' : NOK", test.name) } + if _, ok := dataFile[test.filename]; ok { t.Fatalf("test '%s' : NOK", test.name) } } + log.Printf("test '%s' : OK", test.name) } } @@ -408,21 +421,23 @@ func TestFile(t *testing.T) { if err != nil { log.Fatal(err) } + result, err := expr.Run(compiledFilter, map[string]interface{}{}) if err != nil { log.Fatal(err) } + if isOk := assert.Equal(t, test.result, result); !isOk { t.Fatalf("test '%s' : NOK", test.name) } - log.Printf("test '%s' : OK", test.name) + log.Printf("test '%s' : OK", test.name) } } func TestIpInRange(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string env map[string]interface{} @@ -470,12 +485,11 @@ func TestIpInRange(t *testing.T) { require.Equal(t, test.result, output) log.Printf("test '%s' : OK", test.name) } - } func TestIpToRange(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string env map[string]interface{} @@ -543,13 +557,11 @@ func TestIpToRange(t *testing.T) { require.Equal(t, test.result, output) log.Printf("test '%s' : OK", test.name) } - } func TestAtof(t *testing.T) { - err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -593,13 +605,14 @@ func TestUpper(t *testing.T) { } err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) vm, err := expr.Compile("Upper(testStr)", GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) + v, ok := out.(string) if !ok { t.Fatalf("Upper() should return a string") @@ -612,6 +625,7 @@ func TestUpper(t *testing.T) { func TestTimeNow(t *testing.T) { now, _ := TimeNow() + ti, err := time.Parse(time.RFC3339, now.(string)) if err != nil { t.Fatalf("Error parsing the return value of TimeNow: %s", err) @@ -620,6 +634,7 @@ func TestTimeNow(t *testing.T) { if -1*time.Until(ti) > time.Second { t.Fatalf("TimeNow func should return time.Now().UTC()") } + log.Printf("test 'TimeNow()' : OK") } @@ -894,15 +909,14 @@ func TestLower(t *testing.T) { } func TestGetDecisionsCount(t *testing.T) { - var err error - var start_ip, start_sfx, end_ip, end_sfx int64 - var ip_sz int existingIP := "1.2.3.4" unknownIP := "1.2.3.5" - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP) + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP) if err != nil { t.Errorf("unable to convert '%s' to int: %s", existingIP, err) } + // Add sample data to DB dbClient = getDBClient(t) @@ -921,11 +935,11 @@ func TestGetDecisionsCount(t *testing.T) { SaveX(context.Background()) if decision == nil { - assert.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.Errorf("Failed to create sample decision")) } err = Init(dbClient) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -982,12 +996,10 @@ func TestGetDecisionsCount(t *testing.T) { } } func TestGetDecisionsSinceCount(t *testing.T) { - var err error - var start_ip, start_sfx, end_ip, end_sfx int64 - var ip_sz int existingIP := "1.2.3.4" unknownIP := "1.2.3.5" - ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP) + + ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP) if err != nil { t.Errorf("unable to convert '%s' to int: %s", existingIP, err) } @@ -1008,8 +1020,9 @@ func TestGetDecisionsSinceCount(t *testing.T) { SetOrigin("CAPI"). SaveX(context.Background()) if decision == nil { - assert.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.Errorf("Failed to create sample decision")) } + decision2 := dbClient.Ent.Decision.Create(). SetCreatedAt(time.Now().AddDate(0, 0, -1)). SetUntil(time.Now().AddDate(0, 0, -1)). @@ -1024,12 +1037,13 @@ func TestGetDecisionsSinceCount(t *testing.T) { SetValue(existingIP). SetOrigin("CAPI"). SaveX(context.Background()) + if decision2 == nil { - assert.Error(t, errors.Errorf("Failed to create sample decision")) + require.Error(t, errors.Errorf("Failed to create sample decision")) } err = Init(dbClient) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -1152,6 +1166,7 @@ func TestIsIp(t *testing.T) { if err := Init(nil); err != nil { log.Fatal(err) } + tests := []struct { name string expr string @@ -1235,17 +1250,18 @@ func TestIsIp(t *testing.T) { expectedBuildErr: true, }, } + for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) if tc.expectedBuildErr { - assert.Error(t, err) + require.Error(t, err) return } - assert.NoError(t, err) + require.NoError(t, err) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) - assert.NoError(t, err) + require.NoError(t, err) assert.IsType(t, tc.expected, output) assert.Equal(t, tc.expected, output.(bool)) }) @@ -1255,6 +1271,7 @@ func TestIsIp(t *testing.T) { func TestToString(t *testing.T) { err := Init(nil) require.NoError(t, err) + tests := []struct { name string value interface{} @@ -1290,9 +1307,9 @@ func TestToString(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) - assert.NoError(t, err) + require.NoError(t, err) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) - assert.NoError(t, err) + require.NoError(t, err) require.Equal(t, tc.expected, output) }) } @@ -1338,16 +1355,16 @@ func TestB64Decode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...) if tc.expectedBuildErr { - assert.Error(t, err) + require.Error(t, err) return } - assert.NoError(t, err) + require.NoError(t, err) output, err := expr.Run(vm, map[string]interface{}{"value": tc.value}) if tc.expectedRuntimeErr { - assert.Error(t, err) + require.Error(t, err) return } - assert.NoError(t, err) + require.NoError(t, err) require.Equal(t, tc.expected, output) }) } @@ -1412,9 +1429,9 @@ func TestParseKv(t *testing.T) { "out": outMap, } vm, err := expr.Compile(tc.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) _, err = expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, tc.expected, outMap["a"]) }) } diff --git a/pkg/exprhelpers/jsonextract_test.go b/pkg/exprhelpers/jsonextract_test.go index 481c7d723..1bd45aa2d 100644 --- a/pkg/exprhelpers/jsonextract_test.go +++ b/pkg/exprhelpers/jsonextract_test.go @@ -7,6 +7,7 @@ import ( "github.com/antonmedv/expr" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestJsonExtract(t *testing.T) { @@ -56,14 +57,14 @@ func TestJsonExtract(t *testing.T) { "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } - } + func TestJsonExtractUnescape(t *testing.T) { if err := Init(nil); err != nil { log.Fatal(err) @@ -104,9 +105,9 @@ func TestJsonExtractUnescape(t *testing.T) { "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -167,9 +168,9 @@ func TestJsonExtractSlice(t *testing.T) { "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -223,9 +224,9 @@ func TestJsonExtractObject(t *testing.T) { "target": test.targetField, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -233,7 +234,8 @@ func TestJsonExtractObject(t *testing.T) { func TestToJson(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) + tests := []struct { name string obj interface{} @@ -298,9 +300,9 @@ func TestToJson(t *testing.T) { "obj": test.obj, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) out, err := expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, out) }) } @@ -308,7 +310,8 @@ func TestToJson(t *testing.T) { func TestUnmarshalJSON(t *testing.T) { err := Init(nil) - assert.NoError(t, err) + require.NoError(t, err) + tests := []struct { name string json string @@ -361,11 +364,10 @@ func TestUnmarshalJSON(t *testing.T) { "out": outMap, } vm, err := expr.Compile(test.expr, GetExprOptions(env)...) - assert.NoError(t, err) + require.NoError(t, err) _, err = expr.Run(vm, env) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, test.expectResult, outMap["a"]) }) } - } diff --git a/pkg/hubtest/coverage.go b/pkg/hubtest/coverage.go index edbe10454..dc3d1d13a 100644 --- a/pkg/hubtest/coverage.go +++ b/pkg/hubtest/coverage.go @@ -9,6 +9,7 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule" "github.com/crowdsecurity/crowdsec/pkg/cwhub" + "github.com/crowdsecurity/go-cs-lib/maptools" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) @@ -25,7 +26,7 @@ func (h *HubTest) GetAppsecCoverage() ([]Coverage, error) { } // populate from hub, iterate in alphabetical order - pkeys := sortedMapKeys(h.HubIndex.GetItemMap(cwhub.APPSEC_RULES)) + pkeys := maptools.SortedKeys(h.HubIndex.GetItemMap(cwhub.APPSEC_RULES)) coverage := make([]Coverage, len(pkeys)) for i, name := range pkeys { @@ -84,7 +85,7 @@ func (h *HubTest) GetParsersCoverage() ([]Coverage, error) { } // populate from hub, iterate in alphabetical order - pkeys := sortedMapKeys(h.HubIndex.GetItemMap(cwhub.PARSERS)) + pkeys := maptools.SortedKeys(h.HubIndex.GetItemMap(cwhub.PARSERS)) coverage := make([]Coverage, len(pkeys)) for i, name := range pkeys { @@ -170,7 +171,7 @@ func (h *HubTest) GetScenariosCoverage() ([]Coverage, error) { } // populate from hub, iterate in alphabetical order - pkeys := sortedMapKeys(h.HubIndex.GetItemMap(cwhub.SCENARIOS)) + pkeys := maptools.SortedKeys(h.HubIndex.GetItemMap(cwhub.SCENARIOS)) coverage := make([]Coverage, len(pkeys)) for i, name := range pkeys { diff --git a/pkg/hubtest/parser_assert.go b/pkg/hubtest/parser_assert.go index aadf16af7..db27f710e 100644 --- a/pkg/hubtest/parser_assert.go +++ b/pkg/hubtest/parser_assert.go @@ -3,21 +3,16 @@ package hubtest import ( "bufio" "fmt" - "io" "os" - "sort" "strings" - "time" "github.com/antonmedv/expr" - "github.com/enescakir/emoji" - "github.com/fatih/color" - diff "github.com/r3labs/diff/v2" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" + "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" - "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/crowdsecurity/go-cs-lib/maptools" ) type AssertFail struct { @@ -34,16 +29,9 @@ type ParserAssert struct { NbAssert int Fails []AssertFail Success bool - TestData *ParserResults + TestData *dumps.ParserResults } -type ParserResult struct { - Evt types.Event - Success bool -} - -type ParserResults map[string]map[string][]ParserResult - func NewParserAssert(file string) *ParserAssert { ParserAssert := &ParserAssert{ File: file, @@ -51,7 +39,7 @@ func NewParserAssert(file string) *ParserAssert { Success: false, Fails: make([]AssertFail, 0), AutoGenAssert: false, - TestData: &ParserResults{}, + TestData: &dumps.ParserResults{}, } return ParserAssert @@ -69,7 +57,7 @@ func (p *ParserAssert) AutoGenFromFile(filename string) (string, error) { } func (p *ParserAssert) LoadTest(filename string) error { - parserDump, err := LoadParserDump(filename) + parserDump, err := dumps.LoadParserDump(filename) if err != nil { return fmt.Errorf("loading parser dump file: %+v", err) } @@ -229,13 +217,13 @@ func (p *ParserAssert) AutoGenParserAssert() string { ret := fmt.Sprintf("len(results) == %d\n", len(*p.TestData)) //sort map keys for consistent order - stages := sortedMapKeys(*p.TestData) + stages := maptools.SortedKeys(*p.TestData) for _, stage := range stages { parsers := (*p.TestData)[stage] //sort map keys for consistent order - pnames := sortedMapKeys(parsers) + pnames := maptools.SortedKeys(parsers) for _, parser := range pnames { presults := parsers[parser] @@ -248,7 +236,7 @@ func (p *ParserAssert) AutoGenParserAssert() string { continue } - for _, pkey := range sortedMapKeys(result.Evt.Parsed) { + for _, pkey := range maptools.SortedKeys(result.Evt.Parsed) { pval := result.Evt.Parsed[pkey] if pval == "" { continue @@ -257,7 +245,7 @@ func (p *ParserAssert) AutoGenParserAssert() string { ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Parsed["%s"] == "%s"`+"\n", stage, parser, pidx, pkey, Escape(pval)) } - for _, mkey := range sortedMapKeys(result.Evt.Meta) { + for _, mkey := range maptools.SortedKeys(result.Evt.Meta) { mval := result.Evt.Meta[mkey] if mval == "" { continue @@ -266,7 +254,7 @@ func (p *ParserAssert) AutoGenParserAssert() string { ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Meta["%s"] == "%s"`+"\n", stage, parser, pidx, mkey, Escape(mval)) } - for _, ekey := range sortedMapKeys(result.Evt.Enriched) { + for _, ekey := range maptools.SortedKeys(result.Evt.Enriched) { eval := result.Evt.Enriched[ekey] if eval == "" { continue @@ -275,7 +263,7 @@ func (p *ParserAssert) AutoGenParserAssert() string { ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Enriched["%s"] == "%s"`+"\n", stage, parser, pidx, ekey, Escape(eval)) } - for _, ukey := range sortedMapKeys(result.Evt.Unmarshaled) { + for _, ukey := range maptools.SortedKeys(result.Evt.Unmarshaled) { uval := result.Evt.Unmarshaled[ukey] if uval == "" { continue @@ -328,288 +316,3 @@ func (p *ParserAssert) buildUnmarshaledAssert(ekey string, eval interface{}) []s return ret } - -func LoadParserDump(filepath string) (*ParserResults, error) { - dumpData, err := os.Open(filepath) - if err != nil { - return nil, err - } - defer dumpData.Close() - - results, err := io.ReadAll(dumpData) - if err != nil { - return nil, err - } - - pdump := ParserResults{} - - if err := yaml.Unmarshal(results, &pdump); err != nil { - return nil, err - } - - /* we know that some variables should always be set, - let's check if they're present in last parser output of last stage */ - - stages := sortedMapKeys(pdump) - - var lastStage string - - //Loop over stages to find last successful one with at least one parser - for i := len(stages) - 2; i >= 0; i-- { - if len(pdump[stages[i]]) != 0 { - lastStage = stages[i] - break - } - } - - parsers := make([]string, 0, len(pdump[lastStage])) - - for k := range pdump[lastStage] { - parsers = append(parsers, k) - } - - sort.Strings(parsers) - - if len(parsers) == 0 { - return nil, fmt.Errorf("no parser found. Please install the appropriate parser and retry") - } - - lastParser := parsers[len(parsers)-1] - - for idx, result := range pdump[lastStage][lastParser] { - if result.Evt.StrTime == "" { - log.Warningf("Line %d/%d is missing evt.StrTime. It is most likely a mistake as it will prevent your logs to be processed in time-machine/forensic mode.", idx, len(pdump[lastStage][lastParser])) - } else { - log.Debugf("Line %d/%d has evt.StrTime set to '%s'", idx, len(pdump[lastStage][lastParser]), result.Evt.StrTime) - } - } - - return &pdump, nil -} - -type DumpOpts struct { - Details bool - SkipOk bool - ShowNotOkParsers bool -} - -func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpOpts) { - //note : we can use line -> time as the unique identifier (of acquisition) - state := make(map[time.Time]map[string]map[string]ParserResult) - assoc := make(map[time.Time]string, 0) - - for stage, parsers := range parserResults { - for parser, results := range parsers { - for _, parserRes := range results { - evt := parserRes.Evt - if _, ok := state[evt.Line.Time]; !ok { - state[evt.Line.Time] = make(map[string]map[string]ParserResult) - assoc[evt.Line.Time] = evt.Line.Raw - } - - if _, ok := state[evt.Line.Time][stage]; !ok { - state[evt.Line.Time][stage] = make(map[string]ParserResult) - } - - state[evt.Line.Time][stage][parser] = ParserResult{Evt: evt, Success: parserRes.Success} - } - } - } - - for bname, evtlist := range bucketPour { - for _, evt := range evtlist { - if evt.Line.Raw == "" { - continue - } - - //it might be bucket overflow being reprocessed, skip this - if _, ok := state[evt.Line.Time]; !ok { - state[evt.Line.Time] = make(map[string]map[string]ParserResult) - assoc[evt.Line.Time] = evt.Line.Raw - } - - //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase - //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered - if _, ok := state[evt.Line.Time]["buckets"]; !ok { - state[evt.Line.Time]["buckets"] = make(map[string]ParserResult) - } - - state[evt.Line.Time]["buckets"][bname] = ParserResult{Success: true} - } - } - - yellow := color.New(color.FgYellow).SprintFunc() - red := color.New(color.FgRed).SprintFunc() - green := color.New(color.FgGreen).SprintFunc() - whitelistReason := "" - //get each line - for tstamp, rawstr := range assoc { - if opts.SkipOk { - if _, ok := state[tstamp]["buckets"]["OK"]; ok { - continue - } - } - - fmt.Printf("line: %s\n", rawstr) - - skeys := make([]string, 0, len(state[tstamp])) - - for k := range state[tstamp] { - //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase - //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered - if k == "buckets" { - continue - } - - skeys = append(skeys, k) - } - - sort.Strings(skeys) - - // iterate stage - var prevItem types.Event - - for _, stage := range skeys { - parsers := state[tstamp][stage] - - sep := "├" - presep := "|" - - fmt.Printf("\t%s %s\n", sep, stage) - - pkeys := sortedMapKeys(parsers) - - for idx, parser := range pkeys { - res := parsers[parser].Success - sep := "├" - - if idx == len(pkeys)-1 { - sep = "└" - } - - created := 0 - updated := 0 - deleted := 0 - whitelisted := false - changeStr := "" - detailsDisplay := "" - - if res { - changelog, _ := diff.Diff(prevItem, parsers[parser].Evt) - for _, change := range changelog { - switch change.Type { - case "create": - created++ - - detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), green(change.To)) - case "update": - detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s -> %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), change.From, yellow(change.To)) - - if change.Path[0] == "Whitelisted" && change.To == true { - whitelisted = true - - if whitelistReason == "" { - whitelistReason = parsers[parser].Evt.WhitelistReason - } - } - updated++ - case "delete": - deleted++ - - detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s\n", presep, sep, change.Type, red(strings.Join(change.Path, "."))) - } - } - - prevItem = parsers[parser].Evt - } - - if created > 0 { - changeStr += green(fmt.Sprintf("+%d", created)) - } - - if updated > 0 { - if len(changeStr) > 0 { - changeStr += " " - } - - changeStr += yellow(fmt.Sprintf("~%d", updated)) - } - - if deleted > 0 { - if len(changeStr) > 0 { - changeStr += " " - } - - changeStr += red(fmt.Sprintf("-%d", deleted)) - } - - if whitelisted { - if len(changeStr) > 0 { - changeStr += " " - } - - changeStr += red("[whitelisted]") - } - - if changeStr == "" { - changeStr = yellow("unchanged") - } - - if res { - fmt.Printf("\t%s\t%s %s %s (%s)\n", presep, sep, emoji.GreenCircle, parser, changeStr) - - if opts.Details { - fmt.Print(detailsDisplay) - } - } else if opts.ShowNotOkParsers { - fmt.Printf("\t%s\t%s %s %s\n", presep, sep, emoji.RedCircle, parser) - } - } - } - - sep := "└" - - if len(state[tstamp]["buckets"]) > 0 { - sep = "├" - } - - //did the event enter the bucket pour phase ? - if _, ok := state[tstamp]["buckets"]["OK"]; ok { - fmt.Printf("\t%s-------- parser success %s\n", sep, emoji.GreenCircle) - } else if whitelistReason != "" { - fmt.Printf("\t%s-------- parser success, ignored by whitelist (%s) %s\n", sep, whitelistReason, emoji.GreenCircle) - } else { - fmt.Printf("\t%s-------- parser failure %s\n", sep, emoji.RedCircle) - } - - //now print bucket info - if len(state[tstamp]["buckets"]) > 0 { - fmt.Printf("\t├ Scenarios\n") - } - - bnames := make([]string, 0, len(state[tstamp]["buckets"])) - - for k := range state[tstamp]["buckets"] { - //there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase - //we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered - if k == "OK" { - continue - } - - bnames = append(bnames, k) - } - - sort.Strings(bnames) - - for idx, bname := range bnames { - sep := "├" - if idx == len(bnames)-1 { - sep = "└" - } - - fmt.Printf("\t\t%s %s %s\n", sep, emoji.GreenCircle, bname) - } - - fmt.Println() - } -} diff --git a/pkg/hubtest/scenario_assert.go b/pkg/hubtest/scenario_assert.go index 011d3dcfb..5195b814e 100644 --- a/pkg/hubtest/scenario_assert.go +++ b/pkg/hubtest/scenario_assert.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" + "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -24,11 +25,10 @@ type ScenarioAssert struct { Fails []AssertFail Success bool TestData *BucketResults - PourData *BucketPourInfo + PourData *dumps.BucketPourInfo } type BucketResults []types.Event -type BucketPourInfo map[string][]types.Event func NewScenarioAssert(file string) *ScenarioAssert { ScenarioAssert := &ScenarioAssert{ @@ -38,7 +38,7 @@ func NewScenarioAssert(file string) *ScenarioAssert { Fails: make([]AssertFail, 0), AutoGenAssert: false, TestData: &BucketResults{}, - PourData: &BucketPourInfo{}, + PourData: &dumps.BucketPourInfo{}, } return ScenarioAssert @@ -64,7 +64,7 @@ func (s *ScenarioAssert) LoadTest(filename string, bucketpour string) error { s.TestData = bucketDump if bucketpour != "" { - pourDump, err := LoadBucketPourDump(bucketpour) + pourDump, err := dumps.LoadBucketPourDump(bucketpour) if err != nil { return fmt.Errorf("loading bucket pour dump file '%s': %+v", filename, err) } @@ -252,27 +252,6 @@ func (b BucketResults) Swap(i, j int) { b[i], b[j] = b[j], b[i] } -func LoadBucketPourDump(filepath string) (*BucketPourInfo, error) { - dumpData, err := os.Open(filepath) - if err != nil { - return nil, err - } - defer dumpData.Close() - - results, err := io.ReadAll(dumpData) - if err != nil { - return nil, err - } - - var bucketDump BucketPourInfo - - if err := yaml.Unmarshal(results, &bucketDump); err != nil { - return nil, err - } - - return &bucketDump, nil -} - func LoadScenarioDump(filepath string) (*BucketResults, error) { dumpData, err := os.Open(filepath) if err != nil { diff --git a/pkg/hubtest/utils.go b/pkg/hubtest/utils.go index 6c48cb3a6..9009d0ddd 100644 --- a/pkg/hubtest/utils.go +++ b/pkg/hubtest/utils.go @@ -5,21 +5,25 @@ import ( "net" "os" "path/filepath" - "sort" "time" log "github.com/sirupsen/logrus" ) -func sortedMapKeys[V any](m map[string]V) []string { - keys := make([]string, 0, len(m)) - for k := range m { - keys = append(keys, k) +func IsAlive(target string) (bool, error) { + start := time.Now() + for { + conn, err := net.Dial("tcp", target) + if err == nil { + log.Debugf("'%s' is up after %s", target, time.Since(start)) + conn.Close() + return true, nil + } + time.Sleep(500 * time.Millisecond) + if time.Since(start) > 10*time.Second { + return false, fmt.Errorf("took more than 10s for %s to be available", target) + } } - - sort.Strings(keys) - - return keys } func Copy(src string, dst string) error { @@ -110,19 +114,3 @@ func CopyDir(src string, dest string) error { return nil } - -func IsAlive(target string) (bool, error) { - start := time.Now() - for { - conn, err := net.Dial("tcp", target) - if err == nil { - log.Debugf("'%s' is up after %s", target, time.Since(start)) - conn.Close() - return true, nil - } - time.Sleep(500 * time.Millisecond) - if time.Since(start) > 10*time.Second { - return false, fmt.Errorf("took more than 10s for %s to be available", target) - } - } -} diff --git a/pkg/parser/runtime.go b/pkg/parser/runtime.go index e1b33bc6e..693fb1e7d 100644 --- a/pkg/parser/runtime.go +++ b/pkg/parser/runtime.go @@ -18,6 +18,7 @@ import ( "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" + "github.com/crowdsecurity/crowdsec/pkg/dumps" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/types" ) @@ -229,14 +230,10 @@ func stageidx(stage string, stages []string) int { return -1 } -type ParserResult struct { - Evt types.Event - Success bool -} - var ParseDump bool var DumpFolder string -var StageParseCache map[string]map[string][]ParserResult + +var StageParseCache dumps.ParserResults var StageParseMutex sync.Mutex func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) { @@ -271,9 +268,9 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) if ParseDump { if StageParseCache == nil { StageParseMutex.Lock() - StageParseCache = make(map[string]map[string][]ParserResult) - StageParseCache["success"] = make(map[string][]ParserResult) - StageParseCache["success"][""] = make([]ParserResult, 0) + StageParseCache = make(dumps.ParserResults) + StageParseCache["success"] = make(map[string][]dumps.ParserResult) + StageParseCache["success"][""] = make([]dumps.ParserResult, 0) StageParseMutex.Unlock() } } @@ -282,7 +279,7 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) if ParseDump { StageParseMutex.Lock() if _, ok := StageParseCache[stage]; !ok { - StageParseCache[stage] = make(map[string][]ParserResult) + StageParseCache[stage] = make(map[string][]dumps.ParserResult) } StageParseMutex.Unlock() } @@ -322,13 +319,18 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) } clog.Tracef("node (%s) ret : %v", node.rn, ret) if ParseDump { + parserIdxInStage := 0 StageParseMutex.Lock() if len(StageParseCache[stage][node.Name]) == 0 { - StageParseCache[stage][node.Name] = make([]ParserResult, 0) + StageParseCache[stage][node.Name] = make([]dumps.ParserResult, 0) + parserIdxInStage = len(StageParseCache[stage]) + } else { + parserIdxInStage = StageParseCache[stage][node.Name][0].Idx } StageParseMutex.Unlock() + evtcopy := deepcopy.Copy(event) - parserInfo := ParserResult{Evt: evtcopy.(types.Event), Success: ret} + parserInfo := dumps.ParserResult{Evt: evtcopy.(types.Event), Success: ret, Idx: parserIdxInStage} StageParseMutex.Lock() StageParseCache[stage][node.Name] = append(StageParseCache[stage][node.Name], parserInfo) StageParseMutex.Unlock() diff --git a/pkg/setup/detect_test.go b/pkg/setup/detect_test.go index 98a22db0d..242ade049 100644 --- a/pkg/setup/detect_test.go +++ b/pkg/setup/detect_test.go @@ -353,7 +353,7 @@ func TestUnitFound(t *testing.T) { installed, err := env.UnitFound("crowdsec-setup-detect.service") require.NoError(err) - require.Equal(true, installed) + require.True(installed) } // TODO apply rules to filter a list of Service structs @@ -566,8 +566,8 @@ func TestDetectForcedUnit(t *testing.T) { func TestDetectForcedProcess(t *testing.T) { if runtime.GOOS == "windows" { - t.Skip("skipping on windows") // while looking for service wizard: rule 'ProcessRunning("foobar")': while looking up running processes: could not get Name: A device attached to the system is not functioning. + t.Skip("skipping on windows") } require := require.New(t) diff --git a/pkg/types/event_test.go b/pkg/types/event_test.go index c3261c647..14ca48cd2 100644 --- a/pkg/types/event_test.go +++ b/pkg/types/event_test.go @@ -73,7 +73,7 @@ func TestParseIPSources(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { ips := tt.evt.ParseIPSources() - assert.Equal(t, ips, tt.expected) + assert.Equal(t, tt.expected, ips) }) } } diff --git a/pkg/types/utils.go b/pkg/types/utils.go index b58eda99b..e42c36d8a 100644 --- a/pkg/types/utils.go +++ b/pkg/types/utils.go @@ -46,7 +46,7 @@ func SetDefaultLoggerConfig(cfgMode string, cfgFolder string, cfgLevel log.Level } logLevel = cfgLevel log.SetLevel(logLevel) - logFormatter = &log.TextFormatter{TimestampFormat: "2006-01-02 15:04:05", FullTimestamp: true, ForceColors: forceColors} + logFormatter = &log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true, ForceColors: forceColors} log.SetFormatter(logFormatter) return nil } diff --git a/test/bats/04_capi.bats b/test/bats/04_capi.bats index 04d3a1aa0..64da27a56 100644 --- a/test/bats/04_capi.bats +++ b/test/bats/04_capi.bats @@ -19,7 +19,35 @@ setup() { #---------- +@test "cscli capi status: fails without credentials" { + config_enable_capi + ONLINE_API_CREDENTIALS_YAML="$(config_get '.api.server.online_client.credentials_path')" + # bogus values, won't be used + echo '{"login":"login","password":"password","url":"url"}' > "${ONLINE_API_CREDENTIALS_YAML}" + + config_set "$ONLINE_API_CREDENTIALS_YAML" 'del(.url)' + rune -1 cscli capi status + assert_stderr --partial "can't load CAPI credentials from '$ONLINE_API_CREDENTIALS_YAML' (missing url field)" + + config_set "$ONLINE_API_CREDENTIALS_YAML" 'del(.password)' + rune -1 cscli capi status + assert_stderr --partial "can't load CAPI credentials from '$ONLINE_API_CREDENTIALS_YAML' (missing password field)" + + config_set "$ONLINE_API_CREDENTIALS_YAML" 'del(.login)' + rune -1 cscli capi status + assert_stderr --partial "can't load CAPI credentials from '$ONLINE_API_CREDENTIALS_YAML' (missing login field)" + + rm "${ONLINE_API_CREDENTIALS_YAML}" + rune -1 cscli capi status + assert_stderr --partial "failed to load Local API: loading online client credentials: open ${ONLINE_API_CREDENTIALS_YAML}: no such file or directory" + + config_set 'del(.api.server.online_client)' + rune -1 cscli capi status + assert_stderr --partial "no configuration for Central API (CAPI) in '$CONFIG_YAML'" +} + @test "cscli capi status" { + ./instance-data load config_enable_capi rune -0 cscli capi register --schmilblick githubciXXXXXXXXXXXXXXXXXXXXXXXX rune -1 cscli capi status @@ -60,13 +88,6 @@ setup() { assert_stderr --partial "You can successfully interact with Central API (CAPI)" } -@test "cscli capi status: fails without credentials" { - ONLINE_API_CREDENTIALS_YAML="$(config_get '.api.server.online_client.credentials_path')" - rm "${ONLINE_API_CREDENTIALS_YAML}" - rune -1 cscli capi status - assert_stderr --partial "failed to load Local API: loading online client credentials: failed to read api server credentials configuration file '${ONLINE_API_CREDENTIALS_YAML}': open ${ONLINE_API_CREDENTIALS_YAML}: no such file or directory" -} - @test "capi register must be run from lapi" { config_disable_lapi rune -1 cscli capi register --schmilblick githubciXXXXXXXXXXXXXXXXXXXXXXXX diff --git a/test/bats/09_context.bats b/test/bats/09_context.bats index 6163d53f9..ba2954510 100644 --- a/test/bats/09_context.bats +++ b/test/bats/09_context.bats @@ -46,9 +46,21 @@ teardown() { assert_stderr --partial "loading console context from $CONTEXT_YAML" } +@test "no error if context file is missing but not explicitly set" { + config_set "del(.crowdsec_service.console_context_path)" + rune -0 rm -f "$CONTEXT_YAML" + rune -0 cscli lapi context status --error + refute_stderr + assert_output --partial "No context found on this agent." + rune -0 "$CROWDSEC" -t + refute_stderr --partial "no such file or directory" +} + @test "error if context file is explicitly set but does not exist" { config_set ".crowdsec_service.console_context_path=strenv(CONTEXT_YAML)" rune -0 rm -f "$CONTEXT_YAML" + rune -1 cscli lapi context status --error + assert_stderr --partial "context.yaml: no such file or directory" rune -1 "$CROWDSEC" -t assert_stderr --partial "while checking console_context_path: stat $CONTEXT_YAML: no such file or directory" } diff --git a/test/bats/13_capi_whitelists.bats b/test/bats/13_capi_whitelists.bats index 61d0e641c..d05a9d932 100644 --- a/test/bats/13_capi_whitelists.bats +++ b/test/bats/13_capi_whitelists.bats @@ -30,7 +30,7 @@ teardown() { @test "capi_whitelists: file missing" { rune -0 wait-for \ - --err "capi whitelist file '$CAPI_WHITELISTS_YAML' does not exist" \ + --err "while opening capi whitelist file: open $CAPI_WHITELISTS_YAML: no such file or directory" \ "${CROWDSEC}" } diff --git a/test/bats/20_hub.bats b/test/bats/20_hub.bats index 13b9ac3e6..18e3770bc 100644 --- a/test/bats/20_hub.bats +++ b/test/bats/20_hub.bats @@ -69,6 +69,16 @@ teardown() { assert_output --partial 'crowdsecurity/iptables' } +@test "cscli hub list (invalid index)" { + new_hub=$(jq <"$INDEX_PATH" '."appsec-rules"."crowdsecurity/vpatch-laravel-debug-mode".version="999"') + echo "$new_hub" >"$INDEX_PATH" + rune -0 cscli hub list --error + assert_stderr --partial "invalid hub item appsec-rules:crowdsecurity/vpatch-laravel-debug-mode: latest version missing from index" + + rune -1 cscli appsec-rules install crowdsecurity/vpatch-laravel-debug-mode --force + assert_stderr --partial "error while installing 'crowdsecurity/vpatch-laravel-debug-mode': while downloading crowdsecurity/vpatch-laravel-debug-mode: latest hash missing from index" +} + @test "missing reference in hub index" { new_hub=$(jq <"$INDEX_PATH" 'del(.parsers."crowdsecurity/smb-logs") | del (.scenarios."crowdsecurity/mysql-bf")') echo "$new_hub" >"$INDEX_PATH" diff --git a/wizard.sh b/wizard.sh index da15b2aa3..f9622cd1b 100755 --- a/wizard.sh +++ b/wizard.sh @@ -102,7 +102,7 @@ log_info() { log_fatal() { msg=$1 date=$(date "+%Y-%m-%d %H:%M:%S") - echo -e "${RED}FATA${NC}[${date}] crowdsec_wizard: ${msg}" 1>&2 + echo -e "${RED}FATA${NC}[${date}] crowdsec_wizard: ${msg}" 1>&2 exit 1 } @@ -129,16 +129,16 @@ log_dbg() { detect_services () { DETECTED_SERVICES=() HMENU=() - #list systemd services + # list systemd services SYSTEMD_SERVICES=`systemctl --state=enabled list-unit-files '*.service' | cut -d ' ' -f1` - #raw ps + # raw ps PSAX=`ps ax -o comm=` for SVC in ${SUPPORTED_SERVICES} ; do log_dbg "Checking if service '${SVC}' is running (ps+systemd)" for SRC in "${SYSTEMD_SERVICES}" "${PSAX}" ; do echo ${SRC} | grep ${SVC} >/dev/null if [ $? -eq 0 ]; then - #on centos, apache2 is named httpd + # on centos, apache2 is named httpd if [[ ${SVC} == "httpd" ]] ; then SVC="apache2"; fi @@ -152,12 +152,12 @@ detect_services () { if [[ ${OSTYPE} == "linux-gnu" ]] || [[ ${OSTYPE} == "linux-gnueabihf" ]]; then DETECTED_SERVICES+=("linux") HMENU+=("linux" "on") - else + else log_info "NOT A LINUX" fi; if [[ ${SILENT} == "false" ]]; then - #we put whiptail results in an array, notice the dark magic fd redirection + # we put whiptail results in an array, notice the dark magic fd redirection DETECTED_SERVICES=($(whiptail --separate-output --noitem --ok-button Continue --title "Services to monitor" --checklist "Detected services, uncheck to ignore. Ignored services won't be monitored." 18 70 10 ${HMENU[@]} 3>&1 1>&2 2>&3)) if [ $? -eq 1 ]; then log_err "user bailed out at services selection" @@ -189,28 +189,28 @@ log_locations[mysql]='/var/log/mysql/error.log' log_locations[smb]='/var/log/samba*.log' log_locations[linux]='/var/log/syslog,/var/log/kern.log,/var/log/messages' -#$1 is service name, such those in SUPPORTED_SERVICES +# $1 is service name, such those in SUPPORTED_SERVICES find_logs_for() { ret="" x=${1} - #we have trailing and starting quotes because of whiptail + # we have trailing and starting quotes because of whiptail SVC="${x%\"}" SVC="${SVC#\"}" DETECTED_LOGFILES=() HMENU=() - #log_info "Searching logs for ${SVC} : ${log_locations[${SVC}]}" + # log_info "Searching logs for ${SVC} : ${log_locations[${SVC}]}" - #split the line into an array with ',' separator + # split the line into an array with ',' separator OIFS=${IFS} IFS=',' read -r -a a <<< "${log_locations[${SVC}]}," IFS=${OIFS} - #readarray -td, a <<<"${log_locations[${SVC}]},"; unset 'a[-1]'; + # readarray -td, a <<<"${log_locations[${SVC}]},"; unset 'a[-1]'; for poss_path in "${a[@]}"; do - #Split /var/log/nginx/*.log into '/var/log/nginx' and '*.log' so we can use find + # Split /var/log/nginx/*.log into '/var/log/nginx' and '*.log' so we can use find path=${poss_path%/*} fname=${poss_path##*/} candidates=`find "${path}" -type f -mtime -5 -ctime -5 -name "$fname"` - #We have some candidates, add them + # We have some candidates, add them for final_file in ${candidates} ; do log_dbg "Found logs file for '${SVC}': ${final_file}" DETECTED_LOGFILES+=(${final_file}) @@ -249,12 +249,12 @@ install_collection() { in_array $collection "${DETECTED_SERVICES[@]}" if [[ $? == 0 ]]; then HMENU+=("${collection}" "${description}" "ON") - #in case we're not in interactive mode, assume defaults + # in case we're not in interactive mode, assume defaults COLLECTION_TO_INSTALL+=(${collection}) else if [[ ${collection} == "linux" ]]; then HMENU+=("${collection}" "${description}" "ON") - #in case we're not in interactive mode, assume defaults + # in case we're not in interactive mode, assume defaults COLLECTION_TO_INSTALL+=(${collection}) else HMENU+=("${collection}" "${description}" "OFF") @@ -272,10 +272,10 @@ install_collection() { for collection in "${COLLECTION_TO_INSTALL[@]}"; do log_info "Installing collection '${collection}'" - ${CSCLI_BIN_INSTALLED} collections install "${collection}" > /dev/null 2>&1 || log_err "fail to install collection ${collection}" + ${CSCLI_BIN_INSTALLED} collections install "${collection}" --error done - ${CSCLI_BIN_INSTALLED} parsers install "crowdsecurity/whitelists" > /dev/null 2>&1 || log_err "fail to install collection crowdsec/whitelists" + ${CSCLI_BIN_INSTALLED} parsers install "crowdsecurity/whitelists" --error if [[ ${SILENT} == "false" ]]; then whiptail --msgbox "Out of safety, I installed a parser called 'crowdsecurity/whitelists'. This one will prevent private IP addresses from being banned, feel free to remove it any time." 20 50 fi @@ -285,14 +285,14 @@ install_collection() { fi } -#$1 is the service name, $... is the list of candidate logs (from find_logs_for) +# $1 is the service name, $... is the list of candidate logs (from find_logs_for) genyamllog() { local service="${1}" shift local files=("${@}") - + echo "#Generated acquisition file - wizard.sh (service: ${service}) / files : ${files[@]}" >> ${TMP_ACQUIS_FILE} - + echo "filenames:" >> ${TMP_ACQUIS_FILE} for fd in ${files[@]}; do echo " - ${fd}" >> ${TMP_ACQUIS_FILE} @@ -306,9 +306,9 @@ genyamllog() { genyamljournal() { local service="${1}" shift - + echo "#Generated acquisition file - wizard.sh (service: ${service}) / files : ${files[@]}" >> ${TMP_ACQUIS_FILE} - + echo "journalctl_filter:" >> ${TMP_ACQUIS_FILE} echo " - _SYSTEMD_UNIT="${service}".service" >> ${TMP_ACQUIS_FILE} echo "labels:" >> ${TMP_ACQUIS_FILE} @@ -318,7 +318,7 @@ genyamljournal() { } genacquisition() { - if skip_tmp_acquis; then + if skip_tmp_acquis; then TMP_ACQUIS_FILE="${ACQUIS_TARGET}" ACQUIS_FILE_MSG="acquisition file generated to: ${TMP_ACQUIS_FILE}" else @@ -336,7 +336,7 @@ genacquisition() { log_info "using journald for '${PSVG}'" genyamljournal ${PSVG} fi; - done + done } detect_cs_install () { @@ -371,7 +371,7 @@ check_cs_version () { fi elif [[ $NEW_MINOR_VERSION -gt $CURRENT_MINOR_VERSION ]] ; then log_warn "new version ($NEW_CS_VERSION) is a minor upgrade !" - if [[ $ACTION != "upgrade" ]] ; then + if [[ $ACTION != "upgrade" ]] ; then if [[ ${FORCE_MODE} == "false" ]]; then echo "" echo "We recommend to upgrade with : sudo ./wizard.sh --upgrade " @@ -383,7 +383,7 @@ check_cs_version () { fi elif [[ $NEW_PATCH_VERSION -gt $CURRENT_PATCH_VERSION ]] ; then log_warn "new version ($NEW_CS_VERSION) is a patch !" - if [[ $ACTION != "binupgrade" ]] ; then + if [[ $ACTION != "binupgrade" ]] ; then if [[ ${FORCE_MODE} == "false" ]]; then echo "" echo "We recommend to upgrade binaries only : sudo ./wizard.sh --binupgrade " @@ -406,7 +406,7 @@ check_cs_version () { fi } -#install crowdsec and cscli +# install crowdsec and cscli install_crowdsec() { mkdir -p "${CROWDSEC_DATA_DIR}" (cd config && find patterns -type f -exec install -Dm 644 "{}" "${CROWDSEC_CONFIG_PATH}/{}" \; && cd ../) || exit @@ -418,7 +418,7 @@ install_crowdsec() { mkdir -p "${CROWDSEC_CONFIG_PATH}/appsec-rules" || exit mkdir -p "${CROWDSEC_CONSOLE_DIR}" || exit - #tmp + # tmp mkdir -p /tmp/data mkdir -p /etc/crowdsec/hub/ install -v -m 600 -D "./config/${CLIENT_SECRETS}" "${CROWDSEC_CONFIG_PATH}" 1> /dev/null || exit @@ -490,7 +490,7 @@ install_bins() { install -v -m 755 -D "${CSCLI_BIN}" "${CSCLI_BIN_INSTALLED}" 1> /dev/null || exit which systemctl && systemctl is-active --quiet crowdsec if [ $? -eq 0 ]; then - systemctl stop crowdsec + systemctl stop crowdsec fi install_plugins symlink_bins @@ -508,7 +508,7 @@ symlink_bins() { delete_bins() { log_info "Removing crowdsec binaries" rm -f ${CROWDSEC_BIN_INSTALLED} - rm -f ${CSCLI_BIN_INSTALLED} + rm -f ${CSCLI_BIN_INSTALLED} } delete_plugins() { @@ -535,7 +535,7 @@ install_plugins(){ } check_running_bouncers() { - #when uninstalling, check if user still has bouncers + # when uninstalling, check if user still has bouncers BOUNCERS_COUNT=$(${CSCLI_BIN} bouncers list -o=raw | tail -n +2 | wc -l) if [[ ${BOUNCERS_COUNT} -gt 0 ]] ; then if [[ ${FORCE_MODE} == "false" ]]; then @@ -646,7 +646,7 @@ main() { then return fi - + if [[ "$1" == "uninstall" ]]; then if ! [ $(id -u) = 0 ]; then @@ -685,11 +685,11 @@ main() { log_info "installing crowdsec" install_crowdsec log_dbg "configuring ${CSCLI_BIN_INSTALLED}" - ${CSCLI_BIN_INSTALLED} hub update > /dev/null 2>&1 || (log_err "fail to update crowdsec hub. exiting" && exit 1) + ${CSCLI_BIN_INSTALLED} hub update --error || (log_err "fail to update crowdsec hub. exiting" && exit 1) # detect running services detect_services - if ! [ ${#DETECTED_SERVICES[@]} -gt 0 ] ; then + if ! [ ${#DETECTED_SERVICES[@]} -gt 0 ] ; then log_err "No detected or selected services, stopping." exit 1 fi; @@ -711,11 +711,11 @@ main() { # api register ${CSCLI_BIN_INSTALLED} machines add --force "$(cat /etc/machine-id)" -a -f "${CROWDSEC_CONFIG_PATH}/${CLIENT_SECRETS}" || log_fatal "unable to add machine to the local API" - log_dbg "Crowdsec LAPI registered" - + log_dbg "Crowdsec LAPI registered" + ${CSCLI_BIN_INSTALLED} capi register || log_fatal "unable to register to the Central API" - log_dbg "Crowdsec CAPI registered" - + log_dbg "Crowdsec CAPI registered" + systemctl enable -q crowdsec >/dev/null || log_fatal "unable to enable crowdsec" systemctl start crowdsec >/dev/null || log_fatal "unable to start crowdsec" log_info "enabling and starting crowdsec daemon" @@ -729,7 +729,7 @@ main() { rm -f "${TMP_ACQUIS_FILE}" fi detect_services - if [[ ${DETECTED_SERVICES} == "" ]] ; then + if [[ ${DETECTED_SERVICES} == "" ]] ; then log_err "No detected or selected services, stopping." exit fi; @@ -757,7 +757,7 @@ usage() { echo " ./wizard.sh --docker-mode Will install crowdsec without systemd and generate random machine-id" echo " ./wizard.sh -n|--noop Do nothing" - exit 0 + exit 0 } if [[ $# -eq 0 ]]; then @@ -770,15 +770,15 @@ do case ${key} in --uninstall) ACTION="uninstall" - shift #past argument + shift # past argument ;; --binupgrade) ACTION="binupgrade" - shift #past argument + shift # past argument ;; --upgrade) ACTION="upgrade" - shift #past argument + shift # past argument ;; -i|--install) ACTION="install" @@ -813,11 +813,11 @@ do -f|--force) FORCE_MODE="true" shift - ;; + ;; -v|--verbose) DEBUG_MODE="true" shift - ;; + ;; -h|--help) usage exit 0