Merge branch 'master' into enable_context_in_console

This commit is contained in:
Thibault "bui" Koechlin 2024-01-15 11:16:30 +01:00 committed by GitHub
commit 9a0d13c3c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
98 changed files with 1769 additions and 1708 deletions

View file

@ -36,7 +36,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v4
with:
go-version: "1.21.5"
go-version: "1.21.6"
- name: "Install bats dependencies"
env:

View file

@ -14,7 +14,7 @@ jobs:
build:
strategy:
matrix:
go-version: ["1.21.5"]
go-version: ["1.21.6"]
name: "Build + tests"
runs-on: ubuntu-latest

View file

@ -10,7 +10,7 @@ jobs:
build:
strategy:
matrix:
go-version: ["1.21.5"]
go-version: ["1.21.6"]
name: "Build + tests"
runs-on: ubuntu-latest

View file

@ -11,7 +11,7 @@ jobs:
build:
strategy:
matrix:
go-version: ["1.21.5"]
go-version: ["1.21.6"]
name: "Build + tests"
runs-on: ubuntu-latest

View file

@ -23,7 +23,7 @@ jobs:
build:
strategy:
matrix:
go-version: ["1.21.5"]
go-version: ["1.21.6"]
name: Build
runs-on: windows-2019

View file

@ -74,7 +74,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v4
with:
go-version: "1.21.0"
go-version: "1.21.6"
cache-dependency-path: "**/go.sum"
- run: |

View file

@ -22,7 +22,7 @@ jobs:
build:
strategy:
matrix:
go-version: ["1.21.5"]
go-version: ["1.21.6"]
name: "Build + tests"
runs-on: windows-2022

View file

@ -126,7 +126,7 @@ jobs:
- name: "Set up Go"
uses: actions/setup-go@v4
with:
go-version: "1.21.5"
go-version: "1.21.6"
- name: Create localstack streams
run: |

View file

@ -19,6 +19,7 @@ jobs:
push_to_registry:
name: Push Debian Docker image to Docker Hub
runs-on: ubuntu-latest
if: ${{ github.repository_owner == 'crowdsecurity' }}
steps:
- name: Check out the repo

View file

@ -19,6 +19,7 @@ jobs:
push_to_registry:
name: Push Docker image to Docker Hub
runs-on: ubuntu-latest
if: ${{ github.repository_owner == 'crowdsecurity' }}
steps:
- name: Check out the repo

View file

@ -14,7 +14,7 @@ jobs:
build:
strategy:
matrix:
go-version: ["1.21.5"]
go-version: ["1.21.6"]
name: Build and upload binary package
runs-on: ubuntu-latest

View file

@ -9,6 +9,13 @@ run:
- pkg/yamlpatch/merge_test.go
linters-settings:
gci:
sections:
- standard
- default
- prefix(github.com/crowdsecurity)
- prefix(github.com/crowdsecurity/crowdsec)
gocyclo:
min-complexity: 30
@ -235,42 +242,6 @@ issues:
# Will fix, trivial - just beware of merge conflicts
- linters:
- testifylint
text: "expected-actual: need to reverse actual and expected values"
- linters:
- testifylint
text: "bool-compare: use assert.False"
- linters:
- testifylint
text: "len: use assert.Len"
- linters:
- testifylint
text: "bool-compare: use assert.True"
- linters:
- testifylint
text: "bool-compare: use require.True"
- linters:
- testifylint
text: "require-error: for error assertions use require"
- linters:
- testifylint
text: "error-nil: use assert.NoError"
- linters:
- testifylint
text: "error-nil: use assert.Error"
- linters:
- testifylint
text: "empty: use assert.Empty"
- linters:
- perfsprint
text: "fmt.Sprintf can be replaced .*"
@ -279,10 +250,6 @@ issues:
# Will fix, easy but some neurons required
#
- linters:
- testifylint
text: "float-compare: use assert.InEpsilon .*or InDelta.*"
- linters:
- errorlint
text: "non-wrapping format verb for fmt.Errorf. Use `%w` to format errors"

View file

@ -1,5 +1,6 @@
# vim: set ft=dockerfile:
ARG GOVERSION=1.21.5
ARG GOVERSION=1.21.6
ARG BUILD_VERSION
FROM golang:${GOVERSION}-alpine3.18 AS build
@ -7,6 +8,7 @@ WORKDIR /go/src/crowdsec
# We like to choose the release of re2 to use, and Alpine does not ship a static version anyway.
ENV RE2_VERSION=2023-03-01
ENV BUILD_VERSION=${BUILD_VERSION}
# wizard.sh requires GNU coreutils
RUN apk add --no-cache git g++ gcc libc-dev make bash gettext binutils-gold coreutils pkgconfig && \

View file

@ -1,5 +1,6 @@
# vim: set ft=dockerfile:
ARG GOVERSION=1.21.5
ARG GOVERSION=1.21.6
ARG BUILD_VERSION
FROM golang:${GOVERSION}-bookworm AS build
@ -10,6 +11,7 @@ ENV DEBCONF_NOWARNINGS="yes"
# We like to choose the release of re2 to use, the debian version is usually older.
ENV RE2_VERSION=2023-03-01
ENV BUILD_VERSION=${BUILD_VERSION}
# wizard.sh requires GNU coreutils
RUN apt-get update && \

View file

@ -25,9 +25,9 @@ stages:
custom: 'tool'
arguments: 'install --global SignClient --version 1.3.155'
- task: GoTool@0
displayName: "Install Go 1.20"
displayName: "Install Go"
inputs:
version: '1.21.5'
version: '1.21.6'
- pwsh: |
choco install -y make

View file

@ -38,7 +38,7 @@ func (cli cliCapi) NewCommand() *cobra.Command {
Short: "Manage interaction with Central API (CAPI)",
Args: cobra.MinimumNArgs(1),
DisableAutoGenTag: true,
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
if err := require.LAPI(csConfig); err != nil {
return err
}
@ -58,15 +58,17 @@ func (cli cliCapi) NewCommand() *cobra.Command {
}
func (cli cliCapi) NewRegisterCmd() *cobra.Command {
var capiUserPrefix string
var outputFile string
var (
capiUserPrefix string
outputFile string
)
var cmd = &cobra.Command{
Use: "register",
Short: "Register to Central API (CAPI)",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
RunE: func(_ *cobra.Command, _ []string) error {
var err error
capiUser, err := generateID(capiUserPrefix)
if err != nil {
@ -115,7 +117,7 @@ func (cli cliCapi) NewRegisterCmd() *cobra.Command {
}
log.Printf("Central API credentials written to '%s'", dumpFile)
} else {
fmt.Printf("%s\n", string(apiConfigDump))
fmt.Println(string(apiConfigDump))
}
log.Warning(ReloadMessage())
@ -126,6 +128,7 @@ func (cli cliCapi) NewRegisterCmd() *cobra.Command {
cmd.Flags().StringVarP(&outputFile, "file", "f", "", "output file destination")
cmd.Flags().StringVar(&capiUserPrefix, "schmilblick", "", "set a schmilblick (use in tests only)")
if err := cmd.Flags().MarkHidden("schmilblick"); err != nil {
log.Fatalf("failed to hide flag: %s", err)
}
@ -134,18 +137,14 @@ func (cli cliCapi) NewRegisterCmd() *cobra.Command {
}
func (cli cliCapi) NewStatusCmd() *cobra.Command {
var cmd = &cobra.Command{
cmd := &cobra.Command{
Use: "status",
Short: "Check status with the Central API (CAPI)",
Args: cobra.MinimumNArgs(0),
DisableAutoGenTag: true,
RunE: func(cmd *cobra.Command, args []string) error {
if csConfig.API.Server.OnlineClient == nil {
return fmt.Errorf("please provide credentials for the Central API (CAPI) in '%s'", csConfig.API.Server.OnlineClient.CredentialsFilePath)
}
if csConfig.API.Server.OnlineClient.Credentials == nil {
return fmt.Errorf("no credentials for Central API (CAPI) in '%s'", csConfig.API.Server.OnlineClient.CredentialsFilePath)
RunE: func(_ *cobra.Command, _ []string) error {
if err := require.CAPIRegistered(csConfig); err != nil {
return err
}
password := strfmt.Password(csConfig.API.Server.OnlineClient.Credentials.Password)

View file

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/crowdsecurity/crowdsec/pkg/dumps"
"github.com/crowdsecurity/crowdsec/pkg/hubtest"
)
@ -35,7 +36,7 @@ func GetLineCountForFile(filepath string) (int, error) {
return lc, nil
}
type cliExplain struct {}
type cliExplain struct{}
func NewCLIExplain() *cliExplain {
return &cliExplain{}
@ -109,6 +110,7 @@ tail -n 5 myfile.log | cscli explain --type nginx -f -
flags.Bool("failures", false, "Only show failed lines")
flags.Bool("only-successful-parsers", false, "Only show successful parsers")
flags.String("crowdsec", "crowdsec", "Path to crowdsec")
flags.Bool("no-clean", false, "Don't clean runtime environment after tests")
return cmd
}
@ -136,13 +138,18 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error {
return err
}
opts := hubtest.DumpOpts{}
opts := dumps.DumpOpts{}
opts.Details, err = flags.GetBool("verbose")
if err != nil {
return err
}
no_clean, err := flags.GetBool("no-clean")
if err != nil {
return err
}
opts.SkipOk, err = flags.GetBool("failures")
if err != nil {
return err
@ -172,6 +179,9 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error {
return fmt.Errorf("couldn't create a temporary directory to store cscli explain result: %s", err)
}
defer func() {
if no_clean {
return
}
if _, err := os.Stat(dir); !os.IsNotExist(err) {
if err := os.RemoveAll(dir); err != nil {
log.Errorf("unable to delete temporary directory '%s': %s", dir, err)
@ -254,17 +264,17 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error {
parserDumpFile := filepath.Join(dir, hubtest.ParserResultFileName)
bucketStateDumpFile := filepath.Join(dir, hubtest.BucketPourResultFileName)
parserDump, err := hubtest.LoadParserDump(parserDumpFile)
parserDump, err := dumps.LoadParserDump(parserDumpFile)
if err != nil {
return fmt.Errorf("unable to load parser dump result: %s", err)
}
bucketStateDump, err := hubtest.LoadBucketPourDump(bucketStateDumpFile)
bucketStateDump, err := dumps.LoadBucketPourDump(bucketStateDumpFile)
if err != nil {
return fmt.Errorf("unable to load bucket dump result: %s", err)
}
hubtest.DumpTree(*parserDump, *bucketStateDump, opts)
dumps.DumpTree(*parserDump, *bucketStateDump, opts)
return nil
}

View file

@ -155,6 +155,7 @@ func (cli cliHub) upgrade(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
if didUpdate {
updated++
}
@ -191,18 +192,21 @@ func (cli cliHub) types(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
fmt.Print(string(s))
case "json":
jsonStr, err := json.Marshal(cwhub.ItemTypes)
if err != nil {
return err
}
fmt.Println(string(jsonStr))
case "raw":
for _, itemType := range cwhub.ItemTypes {
fmt.Println(itemType)
}
}
return nil
}

View file

@ -48,10 +48,11 @@ cscli appsec-configs list crowdsecurity/vpatch`,
func NewCLIAppsecRule() *cliItem {
inspectDetail := func(item *cwhub.Item) error {
//Only show the converted rules in human mode
// Only show the converted rules in human mode
if csConfig.Cscli.Output != "human" {
return nil
}
appsecRule := appsec.AppsecCollectionConfig{}
yamlContent, err := os.ReadFile(item.State.LocalPath)
@ -71,8 +72,16 @@ func NewCLIAppsecRule() *cliItem {
if err != nil {
return fmt.Errorf("unable to convert rule %s : %s", rule.Name, err)
}
fmt.Println(convertedRule)
}
switch ruleType { //nolint:gocritic
case appsec_rule.ModsecurityRuleType:
for _, rule := range appsecRule.SecLangRules {
fmt.Println(rule)
}
}
}
return nil

View file

@ -16,6 +16,7 @@ import (
"github.com/spf13/cobra"
"gopkg.in/yaml.v2"
"github.com/crowdsecurity/crowdsec/pkg/dumps"
"github.com/crowdsecurity/crowdsec/pkg/hubtest"
)
@ -679,8 +680,8 @@ func (cli cliHubTest) NewExplainCmd() *cobra.Command {
return fmt.Errorf("unable to load scenario result after run: %s", err)
}
}
opts := hubtest.DumpOpts{}
hubtest.DumpTree(*test.ParserAssert.TestData, *test.ScenarioAssert.PourData, opts)
opts := dumps.DumpOpts{}
dumps.DumpTree(*test.ParserAssert.TestData, *test.ScenarioAssert.PourData, opts)
}
return nil

View file

@ -2,11 +2,11 @@ package main
import (
"fmt"
"slices"
"strings"
"github.com/agext/levenshtein"
"github.com/spf13/cobra"
"slices"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/cwhub"

View file

@ -7,10 +7,10 @@ import (
"io"
"os"
"path/filepath"
"slices"
"strings"
"gopkg.in/yaml.v3"
"slices"
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
)

View file

@ -2,12 +2,13 @@ package main
import (
"os"
"slices"
"time"
"github.com/fatih/color"
cc "github.com/ivanpirog/coloredcobra"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"slices"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/database"
@ -107,7 +108,7 @@ var NoNeedConfig = []string{
func main() {
// set the formatter asap and worry about level later
logFormatter := &log.TextFormatter{TimestampFormat: "2006-01-02 15:04:05", FullTimestamp: true}
logFormatter := &log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true}
log.SetFormatter(logFormatter)
if err := fflag.RegisterAllFeatures(); err != nil {

43
cmd/crowdsec/hook.go Normal file
View file

@ -0,0 +1,43 @@
package main
import (
"io"
"os"
log "github.com/sirupsen/logrus"
)
type ConditionalHook struct {
Writer io.Writer
LogLevels []log.Level
Enabled bool
}
func (hook *ConditionalHook) Fire(entry *log.Entry) error {
if hook.Enabled {
line, err := entry.String()
if err != nil {
return err
}
_, err = hook.Writer.Write([]byte(line))
return err
}
return nil
}
func (hook *ConditionalHook) Levels() []log.Level {
return hook.LogLevels
}
// The primal logging hook is set up before parsing config.yaml.
// Once config.yaml is parsed, the primal hook is disabled if the
// configured logger is writing to stderr. Otherwise it's used to
// report fatal errors and panics to stderr in addition to the log file.
var primalHook = &ConditionalHook{
Writer: os.Stderr,
LogLevels: []log.Level{log.FatalLevel, log.PanicLevel},
Enabled: true,
}

View file

@ -80,11 +80,13 @@ func LoadBuckets(cConfig *csconfig.Config, hub *cwhub.Hub) error {
err error
files []string
)
for _, hubScenarioItem := range hub.GetItemMap(cwhub.SCENARIOS) {
if hubScenarioItem.State.Installed {
files = append(files, hubScenarioItem.State.LocalPath)
}
}
buckets = leakybucket.NewBuckets()
log.Infof("Loading %d scenario files", len(files))
@ -99,6 +101,7 @@ func LoadBuckets(cConfig *csconfig.Config, hub *cwhub.Hub) error {
holders[holderIndex].Profiling = true
}
}
return nil
}
@ -143,8 +146,10 @@ func (l labelsMap) Set(label string) error {
if len(split) != 2 {
return fmt.Errorf("invalid format for label '%s', must be key:value", pair)
}
l[split[0]] = split[1]
}
return nil
}
@ -168,9 +173,11 @@ func (f *Flags) Parse() {
flag.BoolVar(&f.DisableAPI, "no-api", false, "disable local API")
flag.BoolVar(&f.DisableCAPI, "no-capi", false, "disable communication with Central API")
flag.BoolVar(&f.OrderEvent, "order-event", false, "enforce event ordering with significant performance cost")
if runtime.GOOS == "windows" {
flag.StringVar(&f.WinSvc, "winsvc", "", "Windows service Action: Install, Remove etc..")
}
flag.StringVar(&dumpFolder, "dump-data", "", "dump parsers/buckets raw outputs")
flag.Parse()
}
@ -205,6 +212,7 @@ func newLogLevel(curLevelPtr *log.Level, f *Flags) *log.Level {
// avoid returning a new ptr to the same value
return curLevelPtr
}
return &ret
}
@ -238,6 +246,8 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo
return nil, err
}
primalHook.Enabled = (cConfig.Common.LogMedia != "stdout")
if err := csconfig.LoadFeatureFlagsFile(configFile, log.StandardLogger()); err != nil {
return nil, err
}
@ -282,6 +292,7 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo
if cConfig.DisableAPI {
cConfig.Common.Daemonize = false
}
log.Infof("single file mode : log_media=%s daemonize=%t", cConfig.Common.LogMedia, cConfig.Common.Daemonize)
}
@ -291,6 +302,7 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo
if cConfig.Common.Daemonize && runtime.GOOS == "windows" {
log.Debug("Daemonization is not supported on Windows, disabling")
cConfig.Common.Daemonize = false
}
@ -308,6 +320,8 @@ func LoadConfig(configFile string, disableAgent bool, disableAPI bool, quiet boo
var crowdsecT0 time.Time
func main() {
log.AddHook(primalHook)
if err := fflag.RegisterAllFeatures(); err != nil {
log.Fatalf("failed to register features: %s", err)
}
@ -342,5 +356,6 @@ func main() {
if err != nil {
log.Fatal(err)
}
os.Exit(0)
}

View file

@ -76,6 +76,15 @@ func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky
return fmt.Errorf("loading list of installed hub scenarios: %w", err)
}
appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES)
if err != nil {
return fmt.Errorf("loading list of installed hub appsec rules: %w", err)
}
installedScenariosAndAppsecRules := make([]string, 0, len(scenarios)+len(appsecRules))
installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, scenarios...)
installedScenariosAndAppsecRules = append(installedScenariosAndAppsecRules, appsecRules...)
apiURL, err := url.Parse(apiConfig.URL)
if err != nil {
return fmt.Errorf("parsing api url ('%s'): %w", apiConfig.URL, err)
@ -87,14 +96,27 @@ func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky
password := strfmt.Password(apiConfig.Password)
Client, err := apiclient.NewClient(&apiclient.Config{
MachineID: apiConfig.Login,
Password: password,
Scenarios: scenarios,
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
PapiURL: papiURL,
VersionPrefix: "v1",
UpdateScenario: func() ([]string, error) {return hub.GetInstalledItemNames(cwhub.SCENARIOS)},
MachineID: apiConfig.Login,
Password: password,
Scenarios: installedScenariosAndAppsecRules,
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
PapiURL: papiURL,
VersionPrefix: "v1",
UpdateScenario: func() ([]string, error) {
scenarios, err := hub.GetInstalledItemNames(cwhub.SCENARIOS)
if err != nil {
return nil, err
}
appsecRules, err := hub.GetInstalledItemNames(cwhub.APPSEC_RULES)
if err != nil {
return nil, err
}
ret := make([]string, 0, len(scenarios)+len(appsecRules))
ret = append(ret, scenarios...)
ret = append(ret, appsecRules...)
return ret, nil
},
})
if err != nil {
return fmt.Errorf("new client api: %w", err)
@ -102,7 +124,7 @@ func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky
authResp, _, err := Client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
MachineID: &apiConfig.Login,
Password: &password,
Scenarios: scenarios,
Scenarios: installedScenariosAndAppsecRules,
})
if err != nil {
return fmt.Errorf("authenticate watcher (%s): %w", apiConfig.Login, err)

View file

@ -4,10 +4,8 @@ package main
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/writer"
"github.com/crowdsecurity/go-cs-lib/trace"
"github.com/crowdsecurity/go-cs-lib/version"
@ -24,16 +22,6 @@ func StartRunSvc() error {
defer trace.CatchPanic("crowdsec/StartRunSvc")
// Set a default logger with level=fatal on stderr,
// in addition to the one we configure afterwards
log.AddHook(&writer.Hook{
Writer: os.Stderr,
LogLevels: []log.Level{
log.PanicLevel,
log.FatalLevel,
},
})
if cConfig, err = LoadConfig(flags.ConfigFile, flags.DisableAgent, flags.DisableAPI, false); err != nil {
return err
}
@ -46,6 +34,7 @@ func StartRunSvc() error {
// Enable profiling early
if cConfig.Prometheus != nil {
var dbClient *database.Client
var err error
if cConfig.DbConfig != nil {
@ -55,8 +44,11 @@ func StartRunSvc() error {
return fmt.Errorf("unable to create database client: %s", err)
}
}
registerPrometheus(cConfig.Prometheus)
go servePrometheus(cConfig.Prometheus, dbClient, apiReady, agentReady)
}
return Serve(cConfig, apiReady, agentReady)
}

6
go.mod
View file

@ -24,9 +24,9 @@ require (
github.com/buger/jsonparser v1.1.1
github.com/c-robinson/iplib v1.0.3
github.com/cespare/xxhash/v2 v2.2.0
github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f
github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607
github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26
github.com/crowdsecurity/go-cs-lib v0.0.5
github.com/crowdsecurity/go-cs-lib v0.0.6
github.com/crowdsecurity/grokky v0.2.1
github.com/crowdsecurity/machineid v1.0.2
github.com/davecgh/go-spew v1.1.1
@ -201,7 +201,7 @@ require (
go.mongodb.org/mongo-driver v1.9.4 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/net v0.19.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/term v0.15.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.8.1-0.20230428195545-5283a0178901 // indirect

6
go.sum
View file

@ -100,10 +100,14 @@ github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f h1:FkOB9aDw0xzDd14pTarGRLsUNAymONq3dc7zhvsXElg=
github.com/crowdsecurity/coraza/v3 v3.0.0-20231213144607-41d5358da94f/go.mod h1:TrU7Li+z2RHNrPy0TKJ6R65V6Yzpan2sTIRryJJyJso=
github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607 h1:hyrYw3h8clMcRL2u5ooZ3tmwnmJftmhb9Ws1MKmavvI=
github.com/crowdsecurity/coraza/v3 v3.0.0-20240108124027-a62b8d8e5607/go.mod h1:br36fEqurGYZQGit+iDYsIzW0FF6VufMbDzyyLxEuPA=
github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:r97WNVC30Uen+7WnLs4xDScS/Ex988+id2k6mDf8psU=
github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:zpv7r+7KXwgVUZnUNjyP22zc/D7LKjyoY02weH2RBbk=
github.com/crowdsecurity/go-cs-lib v0.0.5 h1:eVLW+BRj3ZYn0xt5/xmgzfbbB8EBo32gM4+WpQQk2e8=
github.com/crowdsecurity/go-cs-lib v0.0.5/go.mod h1:8FMKNGsh3hMZi2SEv6P15PURhEJnZV431XjzzBSuf0k=
github.com/crowdsecurity/go-cs-lib v0.0.6 h1:Ef6MylXe0GaJE9vrfvxEdbHb31+JUP1os+murPz7Pos=
github.com/crowdsecurity/go-cs-lib v0.0.6/go.mod h1:8FMKNGsh3hMZi2SEv6P15PURhEJnZV431XjzzBSuf0k=
github.com/crowdsecurity/grokky v0.2.1 h1:t4VYnDlAd0RjDM2SlILalbwfCrQxtJSMGdQOR0zwkE4=
github.com/crowdsecurity/grokky v0.2.1/go.mod h1:33usDIYzGDsgX1kHAThCbseso6JuWNJXOzRQDGXHtWM=
github.com/crowdsecurity/machineid v1.0.2 h1:wpkpsUghJF8Khtmn/tg6GxgdhLA1Xflerh5lirI+bdc=
@ -807,6 +811,8 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

View file

@ -1,5 +1,5 @@
BUILD_GOVERSION = $(subst go,,$(shell go env GOVERSION))
BUILD_GOVERSION = $(subst go,,$(shell $(GO) env GOVERSION))
go_major_minor = $(subst ., ,$(BUILD_GOVERSION))
GO_MAJOR_VERSION = $(word 1, $(go_major_minor))
@ -9,7 +9,7 @@ GO_VERSION_VALIDATION_ERR_MSG = Golang version ($(BUILD_GOVERSION)) is not suppo
.PHONY: goversion
goversion: $(if $(findstring devel,$(shell go env GOVERSION)),goversion_devel,goversion_check)
goversion: $(if $(findstring devel,$(shell $(GO) env GOVERSION)),goversion_devel,goversion_check)
.PHONY: goversion_devel

View file

@ -8,7 +8,11 @@ MKDIR=mkdir -p
GOOS ?= $(shell go env GOOS)
# Current versioning information from env
BUILD_VERSION?=$(shell git describe --tags)
# The $(or) is used to ignore an empty BUILD_VERSION when it's an envvar,
# like inside a docker build: docker build --build-arg BUILD_VERSION=1.2.3
# as opposed to a make parameter: make BUILD_VERSION=1.2.3
BUILD_VERSION:=$(or $(BUILD_VERSION),$(shell git describe --tags --dirty))
BUILD_TIMESTAMP=$(shell date +%F"_"%T)
DEFAULT_CONFIGDIR?=/etc/crowdsec
DEFAULT_DATADIR?=/var/lib/crowdsec/data

View file

@ -40,15 +40,19 @@ func (f *MockSource) Configure(cfg []byte, logger *log.Entry) error {
if err := f.UnmarshalConfig(cfg); err != nil {
return err
}
if f.Mode == "" {
f.Mode = configuration.CAT_MODE
}
if f.Mode != configuration.CAT_MODE && f.Mode != configuration.TAIL_MODE {
return fmt.Errorf("mode %s is not supported", f.Mode)
}
if f.Toto == "" {
return fmt.Errorf("expect non-empty toto")
}
return nil
}
func (f *MockSource) GetMode() string { return f.Mode }
@ -77,6 +81,7 @@ func appendMockSource() {
if GetDataSourceIface("mock") == nil {
AcquisitionSources["mock"] = func() DataSource { return &MockSource{} }
}
if GetDataSourceIface("mock_cant_run") == nil {
AcquisitionSources["mock_cant_run"] = func() DataSource { return &MockSourceCantRun{} }
}
@ -84,6 +89,7 @@ func appendMockSource() {
func TestDataSourceConfigure(t *testing.T) {
appendMockSource()
tests := []struct {
TestName string
String string
@ -185,22 +191,22 @@ wowo: ajsajasjas
switch tc.TestName {
case "basic_valid_config":
mock := (*ds).Dump().(*MockSource)
assert.Equal(t, mock.Toto, "test_value1")
assert.Equal(t, mock.Mode, "cat")
assert.Equal(t, mock.logger.Logger.Level, log.InfoLevel)
assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"})
assert.Equal(t, "test_value1", mock.Toto)
assert.Equal(t, "cat", mock.Mode)
assert.Equal(t, log.InfoLevel, mock.logger.Logger.Level)
assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels)
case "basic_debug_config":
mock := (*ds).Dump().(*MockSource)
assert.Equal(t, mock.Toto, "test_value1")
assert.Equal(t, mock.Mode, "cat")
assert.Equal(t, mock.logger.Logger.Level, log.DebugLevel)
assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"})
assert.Equal(t, "test_value1", mock.Toto)
assert.Equal(t, "cat", mock.Mode)
assert.Equal(t, log.DebugLevel, mock.logger.Logger.Level)
assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels)
case "basic_tailmode_config":
mock := (*ds).Dump().(*MockSource)
assert.Equal(t, mock.Toto, "test_value1")
assert.Equal(t, mock.Mode, "tail")
assert.Equal(t, mock.logger.Logger.Level, log.DebugLevel)
assert.Equal(t, mock.Labels, map[string]string{"test": "foobar"})
assert.Equal(t, "test_value1", mock.Toto)
assert.Equal(t, "tail", mock.Mode)
assert.Equal(t, log.DebugLevel, mock.logger.Logger.Level)
assert.Equal(t, map[string]string{"test": "foobar"}, mock.Labels)
}
})
}
@ -208,6 +214,7 @@ wowo: ajsajasjas
func TestLoadAcquisitionFromFile(t *testing.T) {
appendMockSource()
tests := []struct {
TestName string
Config csconfig.CrowdsecServiceCfg
@ -284,7 +291,6 @@ func TestLoadAcquisitionFromFile(t *testing.T) {
assert.Len(t, dss, tc.ExpectedLen)
})
}
}
@ -304,9 +310,11 @@ func (f *MockCat) Configure(cfg []byte, logger *log.Entry) error {
if f.Mode == "" {
f.Mode = configuration.CAT_MODE
}
if f.Mode != configuration.CAT_MODE {
return fmt.Errorf("mode %s is not supported", f.Mode)
}
return nil
}
@ -319,6 +327,7 @@ func (f *MockCat) OneShotAcquisition(out chan types.Event, tomb *tomb.Tomb) erro
evt.Line.Src = "test"
out <- evt
}
return nil
}
func (f *MockCat) StreamingAcquisition(chan types.Event, *tomb.Tomb) error {
@ -345,9 +354,11 @@ func (f *MockTail) Configure(cfg []byte, logger *log.Entry) error {
if f.Mode == "" {
f.Mode = configuration.TAIL_MODE
}
if f.Mode != configuration.TAIL_MODE {
return fmt.Errorf("mode %s is not supported", f.Mode)
}
return nil
}
@ -364,6 +375,7 @@ func (f *MockTail) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) erro
out <- evt
}
<-t.Dying()
return nil
}
func (f *MockTail) CanRun() error { return nil }
@ -446,6 +458,7 @@ func (f *MockTailError) StreamingAcquisition(out chan types.Event, t *tomb.Tomb)
out <- evt
}
t.Kill(fmt.Errorf("got error (tomb)"))
return fmt.Errorf("got error")
}
@ -499,6 +512,7 @@ func (f *MockSourceByDSN) ConfigureByDSN(dsn string, labels map[string]string, l
if dsn != "test_expect" {
return fmt.Errorf("unexpected value")
}
return nil
}
func (f *MockSourceByDSN) GetUuid() string { return "" }

View file

@ -335,7 +335,7 @@ func (w *AppsecSource) appsecHandler(rw http.ResponseWriter, r *http.Request) {
}
// parse the request only once
parsedRequest, err := appsec.NewParsedRequestFromRequest(r)
parsedRequest, err := appsec.NewParsedRequestFromRequest(r, w.logger)
if err != nil {
w.logger.Errorf("%s", err)
rw.WriteHeader(http.StatusInternalServerError)

View file

@ -57,6 +57,7 @@ container_name:
subLogger := log.WithFields(log.Fields{
"type": "docker",
})
for _, test := range tests {
f := DockerSource{}
err := f.Configure([]byte(test.config), subLogger)
@ -66,12 +67,15 @@ container_name:
func TestConfigureDSN(t *testing.T) {
log.Infof("Test 'TestConfigureDSN'")
var dockerHost string
if runtime.GOOS == "windows" {
dockerHost = "npipe:////./pipe/docker_engine"
} else {
dockerHost = "unix:///var/run/podman/podman.sock"
}
tests := []struct {
name string
dsn string
@ -106,6 +110,7 @@ func TestConfigureDSN(t *testing.T) {
subLogger := log.WithFields(log.Fields{
"type": "docker",
})
for _, test := range tests {
f := DockerSource{}
err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "")
@ -156,8 +161,11 @@ container_name_regexp:
}
for _, ts := range tests {
var logger *log.Logger
var subLogger *log.Entry
var (
logger *log.Logger
subLogger *log.Entry
)
if ts.expectedOutput != "" {
logger.SetLevel(ts.logLevel)
subLogger = logger.WithFields(log.Fields{
@ -173,10 +181,12 @@ container_name_regexp:
dockerTomb := tomb.Tomb{}
out := make(chan types.Event)
dockerSource := DockerSource{}
err := dockerSource.Configure([]byte(ts.config), subLogger)
if err != nil {
t.Fatalf("Unexpected error : %s", err)
}
dockerSource.Client = new(mockDockerCli)
actualLines := 0
readerTomb := &tomb.Tomb{}
@ -204,21 +214,23 @@ container_name_regexp:
if err := readerTomb.Wait(); err != nil {
t.Fatal(err)
}
if ts.expectedLines != 0 {
assert.Equal(t, ts.expectedLines, actualLines)
}
err = streamTomb.Wait()
if err != nil {
t.Fatalf("docker acquisition error: %s", err)
}
}
}
func (cli *mockDockerCli) ContainerList(ctx context.Context, options dockerTypes.ContainerListOptions) ([]dockerTypes.Container, error) {
if readLogs == true {
return []dockerTypes.Container{}, nil
}
containers := make([]dockerTypes.Container, 0)
container := &dockerTypes.Container{
ID: "12456",
@ -233,16 +245,20 @@ func (cli *mockDockerCli) ContainerLogs(ctx context.Context, container string, o
if readLogs == true {
return io.NopCloser(strings.NewReader("")), nil
}
readLogs = true
data := []string{"docker\n", "test\n", "1234\n"}
ret := ""
for _, line := range data {
startLineByte := make([]byte, 8)
binary.LittleEndian.PutUint32(startLineByte, 1) //stdout stream
binary.BigEndian.PutUint32(startLineByte[4:], uint32(len(line)))
ret += fmt.Sprintf("%s%s", startLineByte, line)
}
r := io.NopCloser(strings.NewReader(ret)) // r type is io.ReadCloser
return r, nil
}
@ -252,6 +268,7 @@ func (cli *mockDockerCli) ContainerInspect(ctx context.Context, c string) (docke
Tty: false,
},
}
return r, nil
}
@ -285,8 +302,11 @@ func TestOneShot(t *testing.T) {
}
for _, ts := range tests {
var subLogger *log.Entry
var logger *log.Logger
var (
subLogger *log.Entry
logger *log.Logger
)
if ts.expectedOutput != "" {
logger.SetLevel(ts.logLevel)
subLogger = logger.WithFields(log.Fields{
@ -307,6 +327,7 @@ func TestOneShot(t *testing.T) {
if err := dockerClient.ConfigureByDSN(ts.dsn, labels, subLogger, ""); err != nil {
t.Fatalf("unable to configure dsn '%s': %s", ts.dsn, err)
}
dockerClient.Client = new(mockDockerCli)
out := make(chan types.Event, 100)
tomb := tomb.Tomb{}
@ -315,8 +336,7 @@ func TestOneShot(t *testing.T) {
// else we do the check before actualLines is incremented ...
if ts.expectedLines != 0 {
assert.Equal(t, ts.expectedLines, len(out))
assert.Len(t, out, ts.expectedLines)
}
}
}

View file

@ -21,6 +21,7 @@ func TestBadConfiguration(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping test on windows")
}
tests := []struct {
config string
expectedErr string
@ -48,6 +49,7 @@ journalctl_filter:
subLogger := log.WithFields(log.Fields{
"type": "journalctl",
})
for _, test := range tests {
f := JournalCtlSource{}
err := f.Configure([]byte(test.config), subLogger)
@ -59,6 +61,7 @@ func TestConfigureDSN(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping test on windows")
}
tests := []struct {
dsn string
expectedErr string
@ -92,9 +95,11 @@ func TestConfigureDSN(t *testing.T) {
expectedErr: "",
},
}
subLogger := log.WithFields(log.Fields{
"type": "journalctl",
})
for _, test := range tests {
f := JournalCtlSource{}
err := f.ConfigureByDSN(test.dsn, map[string]string{"type": "testtype"}, subLogger, "")
@ -106,6 +111,7 @@ func TestOneShot(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping test on windows")
}
tests := []struct {
config string
expectedErr string
@ -137,9 +143,12 @@ journalctl_filter:
},
}
for _, ts := range tests {
var logger *log.Logger
var subLogger *log.Entry
var hook *test.Hook
var (
logger *log.Logger
subLogger *log.Entry
hook *test.Hook
)
if ts.expectedOutput != "" {
logger, hook = test.NewNullLogger()
logger.SetLevel(ts.logLevel)
@ -151,27 +160,32 @@ journalctl_filter:
"type": "journalctl",
})
}
tomb := tomb.Tomb{}
out := make(chan types.Event, 100)
j := JournalCtlSource{}
err := j.Configure([]byte(ts.config), subLogger)
if err != nil {
t.Fatalf("Unexpected error : %s", err)
}
err = j.OneShotAcquisition(out, &tomb)
cstest.AssertErrorContains(t, err, ts.expectedErr)
if err != nil {
continue
}
if ts.expectedLines != 0 {
assert.Equal(t, ts.expectedLines, len(out))
assert.Len(t, out, ts.expectedLines)
}
if ts.expectedOutput != "" {
if hook.LastEntry() == nil {
t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput)
}
assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput)
hook.Reset()
}
@ -182,6 +196,7 @@ func TestStreaming(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping test on windows")
}
tests := []struct {
config string
expectedErr string
@ -202,9 +217,12 @@ journalctl_filter:
},
}
for _, ts := range tests {
var logger *log.Logger
var subLogger *log.Entry
var hook *test.Hook
var (
logger *log.Logger
subLogger *log.Entry
hook *test.Hook
)
if ts.expectedOutput != "" {
logger, hook = test.NewNullLogger()
logger.SetLevel(ts.logLevel)
@ -216,14 +234,18 @@ journalctl_filter:
"type": "journalctl",
})
}
tomb := tomb.Tomb{}
out := make(chan types.Event)
j := JournalCtlSource{}
err := j.Configure([]byte(ts.config), subLogger)
if err != nil {
t.Fatalf("Unexpected error : %s", err)
}
actualLines := 0
if ts.expectedLines != 0 {
go func() {
READLOOP:
@ -240,6 +262,7 @@ journalctl_filter:
err = j.StreamingAcquisition(out, &tomb)
cstest.AssertErrorContains(t, err, ts.expectedErr)
if err != nil {
continue
}
@ -248,16 +271,20 @@ journalctl_filter:
time.Sleep(1 * time.Second)
assert.Equal(t, ts.expectedLines, actualLines)
}
tomb.Kill(nil)
tomb.Wait()
output, _ := exec.Command("pgrep", "-x", "journalctl").CombinedOutput()
if string(output) != "" {
t.Fatalf("Found a journalctl process after killing the tomb !")
}
if ts.expectedOutput != "" {
if hook.LastEntry() == nil {
t.Fatalf("Expected log output '%s' but got nothing !", ts.expectedOutput)
}
assert.Contains(t, hook.LastEntry().Message, ts.expectedOutput)
hook.Reset()
}
@ -270,5 +297,6 @@ func TestMain(m *testing.M) {
fullPath := filepath.Join(currentDir, "test_files")
os.Setenv("PATH", fullPath+":"+os.Getenv("PATH"))
}
os.Exit(m.Run())
}

View file

@ -9,6 +9,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/tomb.v2"
)
@ -78,24 +79,23 @@ webhook_path: /k8s-audit`,
err := f.UnmarshalConfig([]byte(test.config))
assert.NoError(t, err)
require.NoError(t, err)
err = f.Configure([]byte(test.config), subLogger)
assert.NoError(t, err)
require.NoError(t, err)
f.StreamingAcquisition(out, tb)
time.Sleep(1 * time.Second)
tb.Kill(nil)
err = tb.Wait()
if test.expectedErr != "" {
assert.ErrorContains(t, err, test.expectedErr)
require.ErrorContains(t, err, test.expectedErr)
return
}
assert.NoError(t, err)
require.NoError(t, err)
})
}
}
func TestHandler(t *testing.T) {
@ -252,10 +252,10 @@ webhook_path: /k8s-audit`,
f := KubernetesAuditSource{}
err := f.UnmarshalConfig([]byte(test.config))
assert.NoError(t, err)
require.NoError(t, err)
err = f.Configure([]byte(test.config), subLogger)
assert.NoError(t, err)
require.NoError(t, err)
req := httptest.NewRequest(test.method, "/k8s-audit", strings.NewReader(test.body))
w := httptest.NewRecorder()
@ -268,11 +268,11 @@ webhook_path: /k8s-audit`,
assert.Equal(t, test.expectedStatusCode, res.StatusCode)
//time.Sleep(1 * time.Second)
assert.NoError(t, err)
require.NoError(t, err)
tb.Kill(nil)
err = tb.Wait()
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, test.eventCount, eventCount)
})

View file

@ -11,6 +11,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows/svc/eventlog"
"gopkg.in/tomb.v2"
)
@ -124,7 +125,7 @@ event_level: bla`,
}
assert.Contains(t, err.Error(), test.expectedErr)
} else {
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, test.expectedQuery, q)
}
}
@ -221,9 +222,8 @@ event_ids:
}
}
if test.expectedLines == nil {
assert.Equal(t, 0, len(linesRead))
assert.Empty(t, linesRead)
} else {
assert.Equal(t, len(test.expectedLines), len(linesRead))
assert.Equal(t, test.expectedLines, linesRead)
}
to.Kill(nil)

View file

@ -7,6 +7,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAlertContext(t *testing.T) {
@ -29,8 +30,7 @@ func TestNewAlertContext(t *testing.T) {
for _, test := range tests {
fmt.Printf("Running test '%s'\n", test.name)
err := NewAlertContext(test.contextToSend, test.valueLength)
assert.ErrorIs(t, err, test.expectedErr)
require.ErrorIs(t, err, test.expectedErr)
}
}
@ -193,7 +193,7 @@ func TestEventToContext(t *testing.T) {
for _, test := range tests {
fmt.Printf("Running test '%s'\n", test.name)
err := NewAlertContext(test.contextToSend, test.valueLength)
assert.ErrorIs(t, err, nil)
require.NoError(t, err)
metas, _ := EventToContext(test.events)
assert.ElementsMatch(t, test.expectedResult, metas)

View file

@ -2,7 +2,9 @@ package alertcontext
import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"slices"
@ -14,6 +16,8 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
)
var ErrNoContextData = errors.New("no context to send")
// this file is here to avoid circular dependencies between the configuration and the hub
// HubItemWrapper is a wrapper around a hub item to unmarshal only the context part
@ -25,7 +29,7 @@ type HubItemWrapper struct {
// mergeContext adds the context from src to dest.
func mergeContext(dest map[string][]string, src map[string][]string) error {
if len(src) == 0 {
return fmt.Errorf("no context data to merge")
return ErrNoContextData
}
for k, v := range src {
@ -86,8 +90,9 @@ func addContextFromFile(toSend map[string][]string, filePath string) error {
}
err = mergeContext(toSend, newContext)
if err != nil {
log.Warningf("while merging context from %s: %s", filePath, err)
if err != nil && !errors.Is(err, ErrNoContextData) {
// having an empty console/context.yaml is not an error
return err
}
return nil
@ -125,8 +130,10 @@ func LoadConsoleContext(c *csconfig.Config, hub *cwhub.Hub) error {
}
if err := addContextFromFile(c.Crowdsec.ContextToSend, c.Crowdsec.ConsoleContextPath); err != nil {
if !ignoreMissing || !os.IsNotExist(err) {
if !errors.Is(err, fs.ErrNotExist) {
return err
} else if !ignoreMissing {
log.Warningf("while merging context from %s: %s", c.Crowdsec.ConsoleContextPath, err)
}
}

View file

@ -5,13 +5,14 @@ import (
"fmt"
"net/http"
"net/url"
"reflect"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/version"
"github.com/crowdsecurity/crowdsec/pkg/models"
@ -25,12 +26,11 @@ func TestAlertsListAsMachine(t *testing.T) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
log.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -39,19 +39,16 @@ func TestAlertsListAsMachine(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
log.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
defer teardown()
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
if r.URL.RawQuery == "ip=1.2.3.4" {
testMethod(t, r, "GET")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `null`)
return
}
@ -107,36 +104,26 @@ func TestAlertsListAsMachine(t *testing.T) {
]`)
})
tcapacity := int32(5)
tduration := "59m49.264032632s"
torigin := "crowdsec"
tscenario := "crowdsecurity/ssh-bf"
tscope := "Ip"
ttype := "ban"
tvalue := "1.1.1.172"
ttimestamp := "2020-11-28 10:20:46 +0000 UTC"
teventscount := int32(6)
tleakspeed := "10s"
tmessage := "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761"
tscenariohash := "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"
tscenarioversion := "0.1"
tstartat := "2020-11-28 10:20:46.842701127 +0100 +0100"
tstopat := "2020-11-28 10:20:46.845621385 +0100 +0100"
expected := models.GetAlertsResponse{
&models.Alert{
Capacity: &tcapacity,
Capacity: ptr.Of(int32(5)),
CreatedAt: "2020-11-28T10:20:47+01:00",
Decisions: []*models.Decision{
{
Duration: &tduration,
Duration: ptr.Of("59m49.264032632s"),
ID: 1,
Origin: &torigin,
Origin: ptr.Of("crowdsec"),
Scenario: &tscenario,
Scope: &tscope,
Simulated: new(bool), //false,
Type: &ttype,
Simulated: ptr.Of(false),
Type: ptr.Of("ban"),
Value: &tvalue,
},
},
@ -167,16 +154,16 @@ func TestAlertsListAsMachine(t *testing.T) {
Timestamp: &ttimestamp,
},
},
EventsCount: &teventscount,
EventsCount: ptr.Of(int32(6)),
ID: 1,
Leakspeed: &tleakspeed,
Leakspeed: ptr.Of("10s"),
MachineID: "test",
Message: &tmessage,
Remediation: false,
Scenario: &tscenario,
ScenarioHash: &tscenariohash,
ScenarioVersion: &tscenarioversion,
Simulated: new(bool), //(false),
ScenarioHash: ptr.Of("4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"),
ScenarioVersion: ptr.Of("0.1"),
Simulated: ptr.Of(false),
Source: &models.Source{
AsName: "Cloudflare Inc",
AsNumber: "",
@ -188,8 +175,8 @@ func TestAlertsListAsMachine(t *testing.T) {
Scope: &tscope,
Value: &tvalue,
},
StartAt: &tstartat,
StopAt: &tstopat,
StartAt: ptr.Of("2020-11-28 10:20:46.842701127 +0100 +0100"),
StopAt: ptr.Of("2020-11-28 10:20:46.845621385 +0100 +0100"),
},
}
@ -198,30 +185,16 @@ func TestAlertsListAsMachine(t *testing.T) {
//log.Debugf("expected : -> %s", spew.Sdump(expected))
//first one returns data
alerts, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{})
if err != nil {
log.Errorf("test Unable to list alerts : %+v", err)
}
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, expected, *alerts)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if !reflect.DeepEqual(*alerts, expected) {
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
}
//this one doesn't
filter := AlertsListOpts{IPEquals: new(string)}
*filter.IPEquals = "1.2.3.4"
filter := AlertsListOpts{IPEquals: ptr.Of("1.2.3.4")}
alerts, resp, err = client.Alerts.List(context.Background(), filter)
if err != nil {
log.Errorf("test Unable to list alerts : %+v", err)
}
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Empty(t, *alerts)
}
@ -236,9 +209,7 @@ func TestAlertsGetAsMachine(t *testing.T) {
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
log.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -247,12 +218,10 @@ func TestAlertsGetAsMachine(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
log.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
defer teardown()
mux.HandleFunc("/alerts/2", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
w.WriteHeader(http.StatusNotFound)
@ -312,34 +281,24 @@ func TestAlertsGetAsMachine(t *testing.T) {
}`)
})
tcapacity := int32(5)
tduration := "59m49.264032632s"
torigin := "crowdsec"
tscenario := "crowdsecurity/ssh-bf"
tscope := "Ip"
ttype := "ban"
tvalue := "1.1.1.172"
ttimestamp := "2020-11-28 10:20:46 +0000 UTC"
teventscount := int32(6)
tleakspeed := "10s"
tmessage := "Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761"
tscenariohash := "4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"
tscenarioversion := "0.1"
tstartat := "2020-11-28 10:20:46.842701127 +0100 +0100"
tstopat := "2020-11-28 10:20:46.845621385 +0100 +0100"
expected := &models.Alert{
Capacity: &tcapacity,
Capacity: ptr.Of(int32(5)),
CreatedAt: "2020-11-28T10:20:47+01:00",
Decisions: []*models.Decision{
{
Duration: &tduration,
Duration: ptr.Of("59m49.264032632s"),
ID: 1,
Origin: &torigin,
Origin: ptr.Of("crowdsec"),
Scenario: &tscenario,
Scope: &tscope,
Simulated: new(bool), //false,
Simulated: ptr.Of(false),
Type: &ttype,
Value: &tvalue,
},
@ -371,16 +330,16 @@ func TestAlertsGetAsMachine(t *testing.T) {
Timestamp: &ttimestamp,
},
},
EventsCount: &teventscount,
EventsCount: ptr.Of(int32(6)),
ID: 1,
Leakspeed: &tleakspeed,
Leakspeed: ptr.Of("10s"),
MachineID: "test",
Message: &tmessage,
Message: ptr.Of("Ip 1.1.1.172 performed 'crowdsecurity/ssh-bf' (6 events over 2.920062ms) at 2020-11-28 10:20:46.845619968 +0100 CET m=+5.903899761"),
Remediation: false,
Scenario: &tscenario,
ScenarioHash: &tscenariohash,
ScenarioVersion: &tscenarioversion,
Simulated: new(bool), //(false),
ScenarioHash: ptr.Of("4441dcff07020f6690d998b7101e642359ba405c2abb83565bbbdcee36de280f"),
ScenarioVersion: ptr.Of("0.1"),
Simulated: ptr.Of(false),
Source: &models.Source{
AsName: "Cloudflare Inc",
AsNumber: "",
@ -392,24 +351,18 @@ func TestAlertsGetAsMachine(t *testing.T) {
Scope: &tscope,
Value: &tvalue,
},
StartAt: &tstartat,
StopAt: &tstopat,
StartAt: ptr.Of("2020-11-28 10:20:46.842701127 +0100 +0100"),
StopAt: ptr.Of("2020-11-28 10:20:46.845621385 +0100 +0100"),
}
alerts, resp, err := client.Alerts.GetByID(context.Background(), 1)
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if !reflect.DeepEqual(*alerts, *expected) {
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *alerts)
//fail
_, _, err = client.Alerts.GetByID(context.Background(), 2)
assert.Contains(t, fmt.Sprintf("%s", err), "API error: object not found")
cstest.RequireErrorMessage(t, err, "API error: object not found")
}
func TestAlertsCreateAsMachine(t *testing.T) {
@ -420,17 +373,17 @@ func TestAlertsCreateAsMachine(t *testing.T) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "POST")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`["3"]`))
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
log.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -439,10 +392,7 @@ func TestAlertsCreateAsMachine(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
log.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
defer teardown()
@ -452,13 +402,8 @@ func TestAlertsCreateAsMachine(t *testing.T) {
expected := &models.AddAlertsResponse{"3"}
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if !reflect.DeepEqual(*alerts, *expected) {
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *alerts)
}
func TestAlertsDeleteAsMachine(t *testing.T) {
@ -469,18 +414,18 @@ func TestAlertsDeleteAsMachine(t *testing.T) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "DELETE")
assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"message":"0 deleted alerts"}`))
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
log.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -489,25 +434,16 @@ func TestAlertsDeleteAsMachine(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
log.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
defer teardown()
alert := AlertsDeleteOpts{IPEquals: new(string)}
*alert.IPEquals = "1.2.3.4"
alert := AlertsDeleteOpts{IPEquals: ptr.Of("1.2.3.4")}
alerts, resp, err := client.Alerts.Delete(context.Background(), alert)
require.NoError(t, err)
expected := &models.DeleteAlertsResponse{NbDeleted: ""}
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if !reflect.DeepEqual(*alerts, *expected) {
t.Errorf("client.Alerts.List returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *alerts)
}

View file

@ -3,6 +3,7 @@ package apiclient
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
@ -13,7 +14,6 @@ import (
"time"
"github.com/go-openapi/strfmt"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/crowdsec/pkg/fflag"
@ -52,10 +52,12 @@ func (t *APIKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
dump, _ := httputil.DumpRequest(req, true)
log.Tracef("auth-api request: %s", string(dump))
}
// Make the HTTP request.
resp, err := t.transport().RoundTrip(req)
if err != nil {
log.Errorf("auth-api: auth with api key failed return nil response, error: %s", err)
return resp, err
}
@ -115,10 +117,12 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
for i := 0; i < maxAttempts; i++ {
if i > 0 {
if r.withBackOff {
//nolint:gosec
backoff += 10 + rand.Intn(20)
}
log.Infof("retrying in %d seconds (attempt %d of %d)", backoff, i+1, r.maxAttempts)
select {
case <-req.Context().Done():
return resp, req.Context().Err()
@ -134,8 +138,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
resp, err = r.next.RoundTrip(clonedReq)
if err != nil {
left := maxAttempts - i - 1
if left > 0 {
if left := maxAttempts - i - 1; left > 0 {
log.Errorf("error while performing request: %s; %d retries left", err, left)
}
@ -177,7 +180,7 @@ func (t *JWTTransport) refreshJwtToken() error {
log.Debugf("scenarios list updated for '%s'", *t.MachineID)
}
var auth = models.WatcherAuthRequest{
auth := models.WatcherAuthRequest{
MachineID: t.MachineID,
Password: t.Password,
Scenarios: t.Scenarios,
@ -264,13 +267,14 @@ func (t *JWTTransport) refreshJwtToken() error {
// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// in a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
// In a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI
// we use a mutex to avoid this
//We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
// We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request)
t.refreshTokenMutex.Lock()
if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) {
if err := t.refreshJwtToken(); err != nil {
t.refreshTokenMutex.Unlock()
return nil, err
}
}
@ -296,8 +300,9 @@ func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
if err != nil {
/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
// we had an error (network error for example, or 401 because token is refused), reset the token ?
t.Token = ""
return resp, fmt.Errorf("performing jwt auth: %w", err)
}
@ -355,6 +360,7 @@ func cloneRequest(r *http.Request) *http.Request {
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}

View file

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/version"
@ -24,13 +25,11 @@ type BasicMockPayload struct {
}
func getLoginsForMockErrorCases() map[string]int {
loginsForMockErrorCases := map[string]int{
return map[string]int{
"login_400": http.StatusBadRequest,
"login_409": http.StatusConflict,
"login_500": http.StatusInternalServerError,
}
return loginsForMockErrorCases
}
func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
@ -49,7 +48,7 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
w.WriteHeader(http.StatusBadRequest)
}
responseBody := ""
var responseBody string
responseCode, hasFoundErrorMock := loginsForMockErrorCases[payload.MachineID]
if !hasFoundErrorMock {
@ -58,6 +57,7 @@ func initBasicMuxMock(t *testing.T, mux *http.ServeMux, path string) {
} else {
responseBody = fmt.Sprintf("Error %d", responseCode)
}
log.Printf("MockServerReceived > %s // Login : [%s] => Mux response [%d]", newStr, payload.MachineID, responseCode)
w.WriteHeader(responseCode)
@ -76,14 +76,13 @@ func TestWatcherRegister(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
//body: models.WatcherRegistrationRequest{MachineID: &config.MachineID, Password: &config.Password}
initBasicMuxMock(t, mux, "/watchers")
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
// Valid Registration : should retrieve the client and no err
clientconfig := Config{
@ -95,9 +94,7 @@ func TestWatcherRegister(t *testing.T) {
}
client, err := RegisterClient(&clientconfig, &http.Client{})
if client == nil || err != nil {
t.Fatalf("while registering client : %s", err)
}
require.NoError(t, err)
log.Printf("->%T", client)
@ -107,11 +104,8 @@ func TestWatcherRegister(t *testing.T) {
clientconfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
client, err = RegisterClient(&clientconfig, &http.Client{})
if client != nil || err == nil {
t.Fatalf("The RegisterClient function should have returned an error for the response code %d", errorCodeToTest)
} else {
log.Printf("The RegisterClient function handled the error code %d as expected \n\r", errorCodeToTest)
}
require.Nil(t, client, "nil expected for the response code %d", errorCodeToTest)
require.Error(t, err, "error expected for the response code %d", errorCodeToTest)
}
}
@ -126,9 +120,7 @@ func TestWatcherAuth(t *testing.T) {
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok auth
clientConfig := &Config{
@ -139,34 +131,27 @@ func TestWatcherAuth(t *testing.T) {
VersionPrefix: "v1",
Scenarios: []string{"crowdsecurity/test"},
}
client, err := NewClient(clientConfig)
if err != nil {
t.Fatalf("new api client: %s", err)
}
client, err := NewClient(clientConfig)
require.NoError(t, err)
_, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
MachineID: &clientConfig.MachineID,
Password: &clientConfig.Password,
Scenarios: clientConfig.Scenarios,
})
if err != nil {
t.Fatalf("unexpect auth err 0: %s", err)
}
require.NoError(t, err)
// Testing error handling on AuthenticateWatcher (400, 409): should retrieve an error
// Not testing 500 because it loops and try to re-autehnticate. But you can test it manually by adding it in array
errorCodesToTest := [2]int{http.StatusBadRequest, http.StatusConflict}
for _, errorCodeToTest := range errorCodesToTest {
clientConfig.MachineID = fmt.Sprintf("login_%d", errorCodeToTest)
client, err := NewClient(clientConfig)
require.NoError(t, err)
if err != nil {
t.Fatalf("new api client: %s", err)
}
var resp *Response
_, resp, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
_, resp, err := client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{
MachineID: &clientConfig.MachineID,
Password: &clientConfig.Password,
})
@ -175,9 +160,7 @@ func TestWatcherAuth(t *testing.T) {
resp.Response.Body.Close()
bodyBytes, err := io.ReadAll(resp.Response.Body)
if err != nil {
t.Fatalf("error while reading body: %s", err.Error())
}
require.NoError(t, err)
log.Printf(string(bodyBytes))
t.Fatalf("The AuthenticateWatcher function should have returned an error for the response code %d", errorCodeToTest)
@ -199,10 +182,12 @@ func TestWatcherUnregister(t *testing.T) {
assert.Equal(t, int64(0), r.ContentLength)
w.WriteHeader(http.StatusOK)
})
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "POST")
buf := new(bytes.Buffer)
_, _ = buf.ReadFrom(r.Body)
newStr := buf.String()
if newStr == `{"machine_id":"test_login","password":"test_password","scenarios":["crowdsecurity/test"]}
` {
@ -217,9 +202,7 @@ func TestWatcherUnregister(t *testing.T) {
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
mycfg := &Config{
MachineID: "test_login",
@ -229,16 +212,12 @@ func TestWatcherUnregister(t *testing.T) {
VersionPrefix: "v1",
Scenarios: []string{"crowdsecurity/test"},
}
client, err := NewClient(mycfg)
if err != nil {
t.Fatalf("new api client: %s", err)
}
client, err := NewClient(mycfg)
require.NoError(t, err)
_, err = client.Auth.UnregisterWatcher(context.Background())
if err != nil {
t.Fatalf("while registering client : %s", err)
}
require.NoError(t, err)
log.Printf("->%T", client)
}
@ -255,6 +234,7 @@ func TestWatcherEnroll(t *testing.T) {
_, _ = buf.ReadFrom(r.Body)
newStr := buf.String()
log.Debugf("body -> %s", newStr)
if newStr == `{"attachment_key":"goodkey","name":"","tags":[],"overwrite":false}
` {
log.Print("good key")
@ -266,17 +246,17 @@ func TestWatcherEnroll(t *testing.T) {
fmt.Fprintf(w, `{"message":"the attachment key provided is not valid"}`)
}
})
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "POST")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, `{"code":200,"expire":"2029-11-30T14:14:24+01:00","token":"toto"}`)
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
mycfg := &Config{
MachineID: "test_login",
@ -286,16 +266,12 @@ func TestWatcherEnroll(t *testing.T) {
VersionPrefix: "v1",
Scenarios: []string{"crowdsecurity/test"},
}
client, err := NewClient(mycfg)
if err != nil {
t.Fatalf("new api client: %s", err)
}
client, err := NewClient(mycfg)
require.NoError(t, err)
_, err = client.Auth.EnrollWatcher(context.Background(), "goodkey", "", []string{}, false)
if err != nil {
t.Fatalf("unexpect enroll err: %s", err)
}
require.NoError(t, err)
_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
assert.Contains(t, err.Error(), "the attachment key provided is not valid", "got %s", err.Error())

View file

@ -9,6 +9,9 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/ptr"
)
func TestApiAuth(t *testing.T) {
@ -17,6 +20,7 @@ func TestApiAuth(t *testing.T) {
mux, urlx, teardown := setup()
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
if r.Header.Get("X-Api-Key") == "ixu" {
assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
w.WriteHeader(http.StatusOK)
@ -26,11 +30,11 @@ func TestApiAuth(t *testing.T) {
w.Write([]byte(`{"message":"access forbidden"}`))
}
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
defer teardown()
@ -40,18 +44,12 @@ func TestApiAuth(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
alert := DecisionsListOpts{IPEquals: new(string)}
*alert.IPEquals = "1.2.3.4"
_, resp, err := newcli.Decisions.List(context.Background(), alert)
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
alert := DecisionsListOpts{IPEquals: ptr.Of("1.2.3.4")}
_, resp, err := newcli.Decisions.List(context.Background(), alert)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
//ko bad token
auth = &APIKeyTransport{
@ -59,25 +57,21 @@ func TestApiAuth(t *testing.T) {
}
newcli, err = NewDefaultClient(apiURL, "v1", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
_, resp, err = newcli.Decisions.List(context.Background(), alert)
log.Infof("--> %s", err)
if resp.Response.StatusCode != http.StatusForbidden {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
assert.Equal(t, http.StatusForbidden, resp.Response.StatusCode)
cstest.RequireErrorMessage(t, err, "API error: access forbidden")
assert.Contains(t, err.Error(), "API error: access forbidden")
//ko empty token
auth = &APIKeyTransport{}
newcli, err = NewDefaultClient(apiURL, "v1", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
_, _, err = newcli.Decisions.List(context.Background(), alert)
require.Error(t, err)

View file

@ -74,6 +74,7 @@ func NewClient(config *Config) (*ApiClient, error) {
VersionPrefix: config.VersionPrefix,
UpdateScenario: config.UpdateScenario,
}
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
tlsconfig.RootCAs = CaCertPool
@ -180,8 +181,7 @@ func (e *ErrorResponse) Error() string {
}
func newResponse(r *http.Response) *Response {
response := &Response{Response: r}
return response
return &Response{Response: r}
}
func CheckResponse(r *http.Response) error {
@ -192,7 +192,7 @@ func CheckResponse(r *http.Response) error {
errorResponse := &ErrorResponse{}
data, err := io.ReadAll(r.Body)
if err == nil && data != nil {
if err == nil && len(data)>0 {
err := json.Unmarshal(data, errorResponse)
if err != nil {
return fmt.Errorf("http code %d, invalid body: %w", r.StatusCode, err)

View file

@ -8,19 +8,19 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/version"
)
func TestNewRequestInvalid(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
//missing slash in uri
apiURL, err := url.Parse(urlx)
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -29,9 +29,8 @@ func TestNewRequestInvalid(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
/*mock login*/
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
@ -44,17 +43,16 @@ func TestNewRequestInvalid(t *testing.T) {
})
_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
assert.Contains(t, err.Error(), `building request: BaseURL must have a trailing slash, but `)
cstest.RequireErrorContains(t, err, "building request: BaseURL must have a trailing slash, but ")
}
func TestNewRequestTimeout(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
//missing slash in uri
// missing slash in uri
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
@ -63,9 +61,8 @@ func TestNewRequestTimeout(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
/*mock login*/
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
time.Sleep(2 * time.Second)
@ -75,5 +72,5 @@ func TestNewRequestTimeout(t *testing.T) {
defer cancel()
_, _, err = client.Alerts.List(ctx, AlertsListOpts{})
assert.Contains(t, err.Error(), `performing request: context deadline exceeded`)
cstest.RequireErrorMessage(t, err, "performing request: context deadline exceeded")
}

View file

@ -11,7 +11,9 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/version"
)
@ -20,13 +22,13 @@ import (
- each test will then bind handler for the method(s) they want to try
*/
func setup() (mux *http.ServeMux, serverURL string, teardown func()) {
func setup() (*http.ServeMux, string, func()) {
return setupWithPrefix("v1")
}
func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) {
func setupWithPrefix(urlPrefix string) (*http.ServeMux, string, func()) {
// mux is the HTTP request multiplexer used with the test server.
mux = http.NewServeMux()
mux := http.NewServeMux()
baseURLPath := "/" + urlPrefix
apiHandler := http.NewServeMux()
@ -40,19 +42,16 @@ func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, te
func testMethod(t *testing.T, r *http.Request, want string) {
t.Helper()
if got := r.Method; got != want {
t.Errorf("Request method: %v, want %v", got, want)
}
assert.Equal(t, want, r.Method)
}
func TestNewClientOk(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -60,9 +59,8 @@ func TestNewClientOk(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
/*mock login*/
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
@ -75,22 +73,17 @@ func TestNewClientOk(t *testing.T) {
})
_, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{})
if err != nil {
t.Fatalf("test Unable to list alerts : %+v", err)
}
if resp.Response.StatusCode != http.StatusOK {
t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated)
}
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
}
func TestNewClientKo(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -98,9 +91,8 @@ func TestNewClientKo(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
/*mock login*/
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
@ -113,36 +105,36 @@ func TestNewClientKo(t *testing.T) {
})
_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
assert.Contains(t, err.Error(), `API error: bad login/password`)
cstest.RequireErrorContains(t, err, `API error: bad login/password`)
log.Printf("err-> %s", err)
}
func TestNewDefaultClient(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewDefaultClient(apiURL, "/v1", "", nil)
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"code": 401, "message" : "brr"}`))
})
_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
assert.Contains(t, err.Error(), `performing request: API error: brr`)
cstest.RequireErrorMessage(t, err, "performing request: API error: brr")
log.Printf("err-> %s", err)
}
func TestNewClientRegisterKO(t *testing.T) {
apiURL, err := url.Parse("http://127.0.0.1:4242/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
_, err = RegisterClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -150,17 +142,18 @@ func TestNewClientRegisterKO(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
}, &http.Client{})
if runtime.GOOS != "windows" {
assert.Contains(t, fmt.Sprintf("%s", err), "dial tcp 127.0.0.1:4242: connect: connection refused")
cstest.RequireErrorContains(t, err, "dial tcp 127.0.0.1:4242: connect: connection refused")
} else {
assert.Contains(t, fmt.Sprintf("%s", err), " No connection could be made because the target machine actively refused it.")
cstest.RequireErrorContains(t, err, " No connection could be made because the target machine actively refused it.")
}
}
func TestNewClientRegisterOK(t *testing.T) {
log.SetLevel(log.TraceLevel)
mux, urlx, teardown := setup()
mux, urlx, teardown := setup()
defer teardown()
/*mock login*/
@ -171,9 +164,8 @@ func TestNewClientRegisterOK(t *testing.T) {
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := RegisterClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -181,17 +173,15 @@ func TestNewClientRegisterOK(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
}, &http.Client{})
if err != nil {
t.Fatalf("while registering client : %s", err)
}
require.NoError(t, err)
log.Printf("->%T", client)
}
func TestNewClientBadAnswer(t *testing.T) {
log.SetLevel(log.TraceLevel)
mux, urlx, teardown := setup()
mux, urlx, teardown := setup()
defer teardown()
/*mock login*/
@ -200,10 +190,10 @@ func TestNewClientBadAnswer(t *testing.T) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`bad`))
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
_, err = RegisterClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -211,5 +201,5 @@ func TestNewClientBadAnswer(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
}, &http.Client{})
assert.Contains(t, fmt.Sprintf("%s", err), `invalid body: invalid character 'b' looking for beginning of value`)
cstest.RequireErrorContains(t, err, "invalid body: invalid character 'b' looking for beginning of value")
}

View file

@ -183,7 +183,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
req = req.WithContext(ctx)
log.Debugf("[URL] %s %s", req.Method, req.URL)
// we dont use client_http Do method because we need the reader and is not provided. We would be forced to use Pipe and goroutine, etc
// we don't use client_http Do method because we need the reader and is not provided.
// We would be forced to use Pipe and goroutine, etc
resp, err := client.Do(req)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
@ -216,6 +217,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
if resp.StatusCode != http.StatusOK {
log.Debugf("Received nok status code %d for blocklist %s", resp.StatusCode, *blocklist.URL)
return nil, false, nil
}

View file

@ -5,13 +5,13 @@ import (
"fmt"
"net/http"
"net/url"
"reflect"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/version"
@ -38,10 +38,9 @@ func TestDecisionsList(t *testing.T) {
//no results
}
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok answer
auth := &APIKeyTransport{
@ -49,55 +48,32 @@ func TestDecisionsList(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
tduration := "3h59m55.756182786s"
torigin := "cscli"
tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"
tscope := "Ip"
ttype := "ban"
tvalue := "1.2.3.4"
expected := &models.GetDecisionsResponse{
&models.Decision{
Duration: &tduration,
Duration: ptr.Of("3h59m55.756182786s"),
ID: 4,
Origin: &torigin,
Scenario: &tscenario,
Scope: &tscope,
Type: &ttype,
Value: &tvalue,
Origin: ptr.Of("cscli"),
Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"),
Scope: ptr.Of("Ip"),
Type: ptr.Of("ban"),
Value: ptr.Of("1.2.3.4"),
},
}
//OK decisions
decisionsFilter := DecisionsListOpts{IPEquals: new(string)}
*decisionsFilter.IPEquals = "1.2.3.4"
// OK decisions
decisionsFilter := DecisionsListOpts{IPEquals: ptr.Of("1.2.3.4")}
decisions, resp, err := newcli.Decisions.List(context.Background(), decisionsFilter)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if err != nil {
t.Fatalf("new api client: %s", err)
}
if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected)
}
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *decisions)
//Empty return
decisionsFilter = DecisionsListOpts{IPEquals: new(string)}
*decisionsFilter.IPEquals = "1.2.3.5"
decisionsFilter = DecisionsListOpts{IPEquals: ptr.Of("1.2.3.5")}
decisions, resp, err = newcli.Decisions.List(context.Background(), decisionsFilter)
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Empty(t, *decisions)
}
@ -120,6 +96,7 @@ func TestDecisionsStream(t *testing.T) {
}
}
})
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "ixu", r.Header.Get("X-Api-Key"))
testMethod(t, r, http.MethodDelete)
@ -129,9 +106,7 @@ func TestDecisionsStream(t *testing.T) {
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok answer
auth := &APIKeyTransport{
@ -139,63 +114,38 @@ func TestDecisionsStream(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v1", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
tduration := "3h59m55.756182786s"
torigin := "cscli"
tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"
tscope := "Ip"
ttype := "ban"
tvalue := "1.2.3.4"
expected := &models.DecisionsStreamResponse{
New: models.GetDecisionsResponse{
&models.Decision{
Duration: &tduration,
Duration: ptr.Of("3h59m55.756182786s"),
ID: 4,
Origin: &torigin,
Scenario: &tscenario,
Scope: &tscope,
Type: &ttype,
Value: &tvalue,
Origin: ptr.Of("cscli"),
Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"),
Scope: ptr.Of("Ip"),
Type: ptr.Of("ban"),
Value: ptr.Of("1.2.3.4"),
},
},
}
decisions, resp, err := newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: true})
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if err != nil {
t.Fatalf("new api client: %s", err)
}
if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *decisions)
//and second call, we get empty lists
decisions, resp, err = newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: false})
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Empty(t, decisions.New)
assert.Empty(t, decisions.Deleted)
//delete stream
resp, err = newcli.Decisions.StopStream(context.Background())
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
}
func TestDecisionsStreamV3Compatibility(t *testing.T) {
@ -219,9 +169,7 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok answer
auth := &APIKeyTransport{
@ -229,38 +177,30 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
tduration := "3h59m55.756182786s"
torigin := "CAPI"
tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"
tscope := "ip"
ttype := "ban"
tvalue := "1.2.3.4"
tvalue1 := "1.2.3.5"
tscenarioDeleted := "deleted"
tdurationDeleted := "1h"
expected := &models.DecisionsStreamResponse{
New: models.GetDecisionsResponse{
&models.Decision{
Duration: &tduration,
Duration: ptr.Of("3h59m55.756182786s"),
Origin: &torigin,
Scenario: &tscenario,
Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"),
Scope: &tscope,
Type: &ttype,
Value: &tvalue,
Value: ptr.Of("1.2.3.4"),
},
},
Deleted: models.GetDecisionsResponse{
&models.Decision{
Duration: &tdurationDeleted,
Duration: ptr.Of("1h"),
Origin: &torigin,
Scenario: &tscenarioDeleted,
Scenario: ptr.Of("deleted"),
Scope: &tscope,
Type: &ttype,
Value: &tvalue1,
Value: ptr.Of("1.2.3.5"),
},
},
}
@ -268,18 +208,8 @@ func TestDecisionsStreamV3Compatibility(t *testing.T) {
// GetStream is supposed to consume v3 payload and return v2 response
decisions, resp, err := newcli.Decisions.GetStream(context.Background(), DecisionsStreamOpts{Startup: true})
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if err != nil {
t.Fatalf("new api client: %s", err)
}
if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *decisions)
}
func TestDecisionsStreamV3(t *testing.T) {
@ -300,9 +230,7 @@ func TestDecisionsStreamV3(t *testing.T) {
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok answer
auth := &APIKeyTransport{
@ -310,30 +238,19 @@ func TestDecisionsStreamV3(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
tduration := "3h59m55.756182786s"
tscenario := "manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"
tscope := "ip"
tvalue := "1.2.3.4"
tvalue1 := "1.2.3.5"
tdurationBlocklist := "24h"
tnameBlocklist := "blocklist1"
tremediationBlocklist := "ban"
tscopeBlocklist := "ip"
turlBlocklist := "/v3/blocklist"
expected := &modelscapi.GetDecisionsStreamResponse{
New: modelscapi.GetDecisionsStreamResponseNew{
&modelscapi.GetDecisionsStreamResponseNewItem{
Decisions: []*modelscapi.GetDecisionsStreamResponseNewItemDecisionsItems0{
{
Duration: &tduration,
Value: &tvalue,
Duration: ptr.Of("3h59m55.756182786s"),
Value: ptr.Of("1.2.3.4"),
},
},
Scenario: &tscenario,
Scenario: ptr.Of("manual 'ban' from '82929df7ee394b73b81252fe3b4e50203yaT2u6nXiaN7Ix9'"),
Scope: &tscope,
},
},
@ -341,18 +258,18 @@ func TestDecisionsStreamV3(t *testing.T) {
&modelscapi.GetDecisionsStreamResponseDeletedItem{
Scope: &tscope,
Decisions: []string{
tvalue1,
"1.2.3.5",
},
},
},
Links: &modelscapi.GetDecisionsStreamResponseLinks{
Blocklists: []*modelscapi.BlocklistLink{
{
Duration: &tdurationBlocklist,
Name: &tnameBlocklist,
Remediation: &tremediationBlocklist,
Scope: &tscopeBlocklist,
URL: &turlBlocklist,
Duration: ptr.Of("24h"),
Name: ptr.Of("blocklist1"),
Remediation: ptr.Of("ban"),
Scope: ptr.Of("ip"),
URL: ptr.Of("/v3/blocklist"),
},
},
},
@ -361,18 +278,8 @@ func TestDecisionsStreamV3(t *testing.T) {
// GetStream is supposed to consume v3 payload and return v2 response
decisions, resp, err := newcli.Decisions.GetStreamV3(context.Background(), DecisionsStreamOpts{Startup: true})
require.NoError(t, err)
if resp.Response.StatusCode != http.StatusOK {
t.Errorf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusOK)
}
if err != nil {
t.Fatalf("new api client: %s", err)
}
if !reflect.DeepEqual(*decisions, *expected) {
t.Fatalf("returned %+v, want %+v", resp, expected)
}
assert.Equal(t, http.StatusOK, resp.Response.StatusCode)
assert.Equal(t, *expected, *decisions)
}
func TestDecisionsFromBlocklist(t *testing.T) {
@ -383,10 +290,13 @@ func TestDecisionsFromBlocklist(t *testing.T) {
mux.HandleFunc("/blocklist", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, http.MethodGet)
if r.Header.Get("If-Modified-Since") == "Sun, 01 Jan 2023 01:01:01 GMT" {
w.WriteHeader(http.StatusNotModified)
return
}
if r.Method == http.MethodGet {
w.WriteHeader(http.StatusOK)
w.Write([]byte("1.2.3.4\r\n1.2.3.5"))
@ -394,9 +304,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
})
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
//ok answer
auth := &APIKeyTransport{
@ -404,12 +312,8 @@ func TestDecisionsFromBlocklist(t *testing.T) {
}
newcli, err := NewDefaultClient(apiURL, "v3", "toto", auth.Client())
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
tvalue1 := "1.2.3.4"
tvalue2 := "1.2.3.5"
tdurationBlocklist := "24h"
tnameBlocklist := "blocklist1"
tremediationBlocklist := "ban"
@ -419,7 +323,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
expected := []*models.Decision{
{
Duration: &tdurationBlocklist,
Value: &tvalue1,
Value: ptr.Of("1.2.3.4"),
Scenario: &tnameBlocklist,
Scope: &tscopeBlocklist,
Type: &tremediationBlocklist,
@ -427,7 +331,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
},
{
Duration: &tdurationBlocklist,
Value: &tvalue2,
Value: ptr.Of("1.2.3.5"),
Scenario: &tnameBlocklist,
Scope: &tscopeBlocklist,
Type: &tremediationBlocklist,
@ -450,13 +354,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
log.Infof("expected : %s, %s, %s, %s, %s", *expected[0].Value, *expected[0].Duration, *expected[0].Scenario, *expected[0].Scope, *expected[0].Type)
log.Infof("decisions: %s, %s, %s, %s, %s", *decisions[1].Value, *decisions[1].Duration, *decisions[1].Scenario, *decisions[1].Scope, *decisions[1].Type)
if err != nil {
t.Fatalf("new api client: %s", err)
}
if !reflect.DeepEqual(decisions, expected) {
t.Fatalf("returned %+v, want %+v", decisions, expected)
}
assert.Equal(t, expected, decisions)
// test cache control
_, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
@ -466,8 +364,10 @@ func TestDecisionsFromBlocklist(t *testing.T) {
Name: &tnameBlocklist,
Duration: &tdurationBlocklist,
}, ptr.Of("Sun, 01 Jan 2023 01:01:01 GMT"))
require.NoError(t, err)
assert.False(t, isModified)
_, isModified, err = newcli.Decisions.GetDecisionsFromBlocklist(context.Background(), &modelscapi.BlocklistLink{
URL: &turlBlocklist,
Scope: &tscopeBlocklist,
@ -475,6 +375,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
Name: &tnameBlocklist,
Duration: &tdurationBlocklist,
}, ptr.Of("Mon, 02 Jan 2023 01:01:01 GMT"))
require.NoError(t, err)
assert.True(t, isModified)
}
@ -485,6 +386,7 @@ func TestDeleteDecisions(t *testing.T) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
mux.HandleFunc("/decisions", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "DELETE")
assert.Equal(t, "ip=1.2.3.4", r.URL.RawQuery)
@ -492,11 +394,12 @@ func TestDeleteDecisions(t *testing.T) {
w.Write([]byte(`{"nbDeleted":"1"}`))
//w.Write([]byte(`{"message":"0 deleted alerts"}`))
})
log.Printf("URL is %s", urlx)
apiURL, err := url.Parse(urlx + "/")
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
require.NoError(t, err)
client, err := NewClient(&Config{
MachineID: "test_login",
Password: "test_password",
@ -504,18 +407,13 @@ func TestDeleteDecisions(t *testing.T) {
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
require.NoError(t, err)
filters := DecisionsDeleteOpts{IPEquals: new(string)}
*filters.IPEquals = "1.2.3.4"
deleted, _, err := client.Decisions.Delete(context.Background(), filters)
if err != nil {
t.Fatalf("unexpected err : %s", err)
}
deleted, _, err := client.Decisions.Delete(context.Background(), filters)
require.NoError(t, err)
assert.Equal(t, "1", deleted.NbDeleted)
defer teardown()
@ -530,22 +428,23 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
ScenariosContaining string
ScenariosNotContaining string
}
tests := []struct {
name string
fields fields
want string
wantErr bool
name string
fields fields
expected string
expectedErr string
}{
{
name: "no filter",
want: baseURLString + "?",
name: "no filter",
expected: baseURLString + "?",
},
{
name: "startup=true",
fields: fields{
Startup: true,
},
want: baseURLString + "?startup=true",
expected: baseURLString + "?startup=true",
},
{
name: "set all params",
@ -555,7 +454,7 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
ScenariosContaining: "ssh",
ScenariosNotContaining: "bf",
},
want: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true",
expected: baseURLString + "?scenarios_containing=ssh&scenarios_not_containing=bf&scopes=ip%2Crange&startup=true",
},
}
@ -568,25 +467,20 @@ func TestDecisionsStreamOpts_addQueryParamsToURL(t *testing.T) {
ScenariosContaining: tt.fields.ScenariosContaining,
ScenariosNotContaining: tt.fields.ScenariosNotContaining,
}
got, err := o.addQueryParamsToURL(baseURLString)
if (err != nil) != tt.wantErr {
t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() error = %v, wantErr %v", err, tt.wantErr)
cstest.RequireErrorContains(t, err, tt.expectedErr)
if tt.expectedErr != "" {
return
}
gotURL, err := url.Parse(got)
if err != nil {
t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() got error while parsing URL: %s", err)
}
require.NoError(t, err)
expectedURL, err := url.Parse(tt.want)
if err != nil {
t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() got error while parsing URL: %s", err)
}
expectedURL, err := url.Parse(tt.expected)
require.NoError(t, err)
if *gotURL != *expectedURL {
t.Errorf("DecisionsStreamOpts.addQueryParamsToURL() = %v, want %v", *gotURL, *expectedURL)
}
assert.Equal(t, *expectedURL, *gotURL)
})
}
}

View file

@ -38,11 +38,13 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
select {
case <-hbTimer.C:
log.Debug("heartbeat: sending heartbeat")
ok, resp, err := h.Ping(ctx)
if err != nil {
log.Errorf("heartbeat error : %s", err)
continue
}
resp.Response.Body.Close()
if resp.Response.StatusCode != http.StatusOK {
log.Errorf("heartbeat unexpected return code : %d", resp.Response.StatusCode)

View file

@ -10,8 +10,8 @@ import (
"testing"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/csplugin"
@ -22,21 +22,14 @@ type LAPI struct {
router *gin.Engine
loginResp models.WatcherAuthResponse
bouncerKey string
t *testing.T
DBConfig *csconfig.DatabaseCfg
}
func SetupLAPITest(t *testing.T) LAPI {
t.Helper()
router, loginResp, config, err := InitMachineTest(t)
if err != nil {
t.Fatal(err)
}
router, loginResp, config := InitMachineTest(t)
APIKey, err := CreateTestBouncer(config.API.Server.DbConfig)
if err != nil {
t.Fatal(err)
}
APIKey := CreateTestBouncer(t, config.API.Server.DbConfig)
return LAPI{
router: router,
@ -46,24 +39,23 @@ func SetupLAPITest(t *testing.T) LAPI {
}
}
func (l *LAPI) InsertAlertFromFile(path string) *httptest.ResponseRecorder {
alertReader := GetAlertReaderFromFile(path)
return l.RecordResponse(http.MethodPost, "/v1/alerts", alertReader, "password")
func (l *LAPI) InsertAlertFromFile(t *testing.T, path string) *httptest.ResponseRecorder {
alertReader := GetAlertReaderFromFile(t, path)
return l.RecordResponse(t, http.MethodPost, "/v1/alerts", alertReader, "password")
}
func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder {
func (l *LAPI) RecordResponse(t *testing.T, verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder {
w := httptest.NewRecorder()
req, err := http.NewRequest(verb, url, body)
if err != nil {
l.t.Fatal(err)
}
require.NoError(t, err)
if authType == "apikey" {
switch authType {
case "apikey":
req.Header.Add("X-Api-Key", l.bouncerKey)
} else if authType == "password" {
case "password":
AddAuthHeaders(req, l.loginResp)
} else {
l.t.Fatal("auth type not supported")
default:
t.Fatal("auth type not supported")
}
l.router.ServeHTTP(w, req)
@ -71,29 +63,16 @@ func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, aut
return w
}
func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config, error) {
router, config, err := NewAPITest(t)
if err != nil {
return nil, models.WatcherAuthResponse{}, config, fmt.Errorf("unable to run local API: %s", err)
}
func InitMachineTest(t *testing.T) (*gin.Engine, models.WatcherAuthResponse, csconfig.Config) {
router, config := NewAPITest(t)
loginResp := LoginToTestAPI(t, router, config)
loginResp, err := LoginToTestAPI(router, config)
if err != nil {
return nil, models.WatcherAuthResponse{}, config, err
}
return router, loginResp, config, nil
return router, loginResp, config
}
func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherAuthResponse, error) {
body, err := CreateTestMachine(router)
if err != nil {
return models.WatcherAuthResponse{}, err
}
err = ValidateMachine("test", config.API.Server.DbConfig)
if err != nil {
log.Fatalln(err)
}
func LoginToTestAPI(t *testing.T, router *gin.Engine, config csconfig.Config) models.WatcherAuthResponse {
body := CreateTestMachine(t, router)
ValidateMachine(t, "test", config.API.Server.DbConfig)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(body))
@ -101,12 +80,10 @@ func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherA
router.ServeHTTP(w, req)
loginResp := models.WatcherAuthResponse{}
err = json.NewDecoder(w.Body).Decode(&loginResp)
if err != nil {
return models.WatcherAuthResponse{}, err
}
err := json.NewDecoder(w.Body).Decode(&loginResp)
require.NoError(t, err)
return loginResp, nil
return loginResp
}
func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthResponse) {
@ -116,17 +93,17 @@ func AddAuthHeaders(request *http.Request, authResponse models.WatcherAuthRespon
func TestSimulatedAlert(t *testing.T) {
lapi := SetupLAPITest(t)
lapi.InsertAlertFromFile("./tests/alert_minibulk+simul.json")
alertContent := GetAlertReaderFromFile("./tests/alert_minibulk+simul.json")
lapi.InsertAlertFromFile(t, "./tests/alert_minibulk+simul.json")
alertContent := GetAlertReaderFromFile(t, "./tests/alert_minibulk+simul.json")
//exclude decision in simulation mode
w := lapi.RecordResponse("GET", "/v1/alerts?simulated=false", alertContent, "password")
w := lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", alertContent, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `)
assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `)
//include decision in simulation mode
w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", alertContent, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", alertContent, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `)
@ -136,35 +113,29 @@ func TestCreateAlert(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Alert with invalid format
w := lapi.RecordResponse(http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password")
w := lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", strings.NewReader("test"), "password")
assert.Equal(t, 400, w.Code)
assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String())
assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
// Create Alert with invalid input
alertContent := GetAlertReaderFromFile("./tests/invalidAlert_sample.json")
alertContent := GetAlertReaderFromFile(t, "./tests/invalidAlert_sample.json")
w = lapi.RecordResponse(http.MethodPost, "/v1/alerts", alertContent, "password")
w = lapi.RecordResponse(t, http.MethodPost, "/v1/alerts", alertContent, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"validation failure list:\\n0.scenario in body is required\\n0.scenario_hash in body is required\\n0.scenario_version in body is required\\n0.simulated in body is required\\n0.source in body is required\"}", w.Body.String())
assert.Equal(t, `{"message":"validation failure list:\n0.scenario in body is required\n0.scenario_hash in body is required\n0.scenario_version in body is required\n0.simulated in body is required\n0.source in body is required"}`, w.Body.String())
// Create Valid Alert
w = lapi.InsertAlertFromFile("./tests/alert_sample.json")
w = lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assert.Equal(t, 201, w.Code)
assert.Equal(t, "[\"1\"]", w.Body.String())
assert.Equal(t, `["1"]`, w.Body.String())
}
func TestCreateAlertChannels(t *testing.T) {
apiServer, config, err := NewAPIServer(t)
if err != nil {
log.Fatalln(err)
}
apiServer, config := NewAPIServer(t)
apiServer.controller.PluginChannel = make(chan csplugin.ProfileAlert)
apiServer.InitController()
loginResp, err := LoginToTestAPI(apiServer.router, config)
if err != nil {
log.Fatalln(err)
}
loginResp := LoginToTestAPI(t, apiServer.router, config)
lapi := LAPI{router: apiServer.router, loginResp: loginResp}
var (
@ -180,7 +151,7 @@ func TestCreateAlertChannels(t *testing.T) {
wg.Done()
}()
go lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json")
go lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json")
wg.Wait()
assert.Len(t, pd.Alert.Decisions, 1)
apiServer.Close()
@ -188,18 +159,18 @@ func TestCreateAlertChannels(t *testing.T) {
func TestAlertListFilters(t *testing.T) {
lapi := SetupLAPITest(t)
lapi.InsertAlertFromFile("./tests/alert_ssh-bf.json")
alertContent := GetAlertReaderFromFile("./tests/alert_ssh-bf.json")
lapi.InsertAlertFromFile(t, "./tests/alert_ssh-bf.json")
alertContent := GetAlertReaderFromFile(t, "./tests/alert_ssh-bf.json")
//bad filter
w := lapi.RecordResponse("GET", "/v1/alerts?test=test", alertContent, "password")
w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", alertContent, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String())
assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String())
//get without filters
w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password")
assert.Equal(t, 200, w.Code)
//check alert and decision
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
@ -207,149 +178,149 @@ func TestAlertListFilters(t *testing.T) {
//test decision_type filter (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ban", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ban", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test decision_type filter (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ratata", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?decision_type=ratata", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test scope (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?scope=Ip", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=Ip", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test scope (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?scope=rarara", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?scope=rarara", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test scenario (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test scenario (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test ip (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test ip (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test ip (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?ip=gruueq", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?ip=gruueq", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String())
//test range (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test range
w = lapi.RecordResponse("GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test range (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?range=ratata", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?range=ratata", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String())
//test since (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?since=1h", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1h", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test since (ok but yields no results)
w = lapi.RecordResponse("GET", "/v1/alerts?since=1ns", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1ns", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test since (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?since=1zuzu", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?since=1zuzu", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`)
//test until (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?until=1ns", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1ns", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test until (ok but no return)
w = lapi.RecordResponse("GET", "/v1/alerts?until=1m", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1m", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test until (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?until=1zuzu", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?until=1zuzu", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`)
//test simulated (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=true", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test simulated (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?simulated=false", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?simulated=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test has active decision
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=true", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=true", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test has active decision
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=false", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=false", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String())
//test has active decision (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String())
}
@ -357,32 +328,32 @@ func TestAlertListFilters(t *testing.T) {
func TestAlertBulkInsert(t *testing.T) {
lapi := SetupLAPITest(t)
//insert a bulk of 20 alerts to trigger bulk insert
lapi.InsertAlertFromFile("./tests/alert_bulk.json")
alertContent := GetAlertReaderFromFile("./tests/alert_bulk.json")
lapi.InsertAlertFromFile(t, "./tests/alert_bulk.json")
alertContent := GetAlertReaderFromFile(t, "./tests/alert_bulk.json")
w := lapi.RecordResponse("GET", "/v1/alerts", alertContent, "password")
w := lapi.RecordResponse(t, "GET", "/v1/alerts", alertContent, "password")
assert.Equal(t, 200, w.Code)
}
func TestListAlert(t *testing.T) {
lapi := SetupLAPITest(t)
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// List Alert with invalid filter
w := lapi.RecordResponse("GET", "/v1/alerts?test=test", emptyBody, "password")
w := lapi.RecordResponse(t, "GET", "/v1/alerts?test=test", emptyBody, "password")
assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String())
assert.Equal(t, `{"message":"Filter parameter 'test' is unknown (=test): invalid filter"}`, w.Body.String())
// List Alert
w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password")
w = lapi.RecordResponse(t, "GET", "/v1/alerts", emptyBody, "password")
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "crowdsecurity/test")
}
func TestCreateAlertErrors(t *testing.T) {
lapi := SetupLAPITest(t)
alertContent := GetAlertReaderFromFile("./tests/alert_sample.json")
alertContent := GetAlertReaderFromFile(t, "./tests/alert_sample.json")
//test invalid bearer
w := httptest.NewRecorder()
@ -403,7 +374,7 @@ func TestCreateAlertErrors(t *testing.T) {
func TestDeleteAlert(t *testing.T) {
lapi := SetupLAPITest(t)
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// Fail Delete Alert
w := httptest.NewRecorder()
@ -426,7 +397,7 @@ func TestDeleteAlert(t *testing.T) {
func TestDeleteAlertByID(t *testing.T) {
lapi := SetupLAPITest(t)
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// Fail Delete Alert
w := httptest.NewRecorder()
@ -454,25 +425,18 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
cfg.API.Server.TrustedIPs = []string{"1.2.3.4", "1.2.4.0/24"}
cfg.API.Server.ListenURI = "::8080"
server, err := NewServer(cfg.API.Server)
if err != nil {
log.Fatal(err)
}
require.NoError(t, err)
err = server.InitController()
if err != nil {
log.Fatal(err)
}
require.NoError(t, err)
router, err := server.Router()
if err != nil {
log.Fatal(err)
}
loginResp, err := LoginToTestAPI(router, cfg)
if err != nil {
log.Fatal(err)
}
require.NoError(t, err)
loginResp := LoginToTestAPI(t, router, cfg)
lapi := LAPI{
router: router,
loginResp: loginResp,
t: t,
}
assertAlertDeleteFailedFromIP := func(ip string) {
@ -498,17 +462,17 @@ func TestDeleteAlertTrustedIPS(t *testing.T) {
assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
}
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assertAlertDeleteFailedFromIP("4.3.2.1")
assertAlertDeletedFromIP("1.2.3.4")
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assertAlertDeletedFromIP("1.2.4.0")
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assertAlertDeletedFromIP("1.2.4.1")
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assertAlertDeletedFromIP("1.2.4.255")
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
assertAlertDeletedFromIP("127.0.0.1")
}

View file

@ -6,20 +6,14 @@ import (
"strings"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func TestAPIKey(t *testing.T) {
router, config, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, config := NewAPITest(t)
APIKey := CreateTestBouncer(t, config.API.Server.DbConfig)
APIKey, err := CreateTestBouncer(config.API.Server.DbConfig)
if err != nil {
log.Fatal(err)
}
// Login with empty token
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/v1/decisions", strings.NewReader(""))
@ -27,7 +21,7 @@ func TestAPIKey(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
assert.Equal(t, "{\"message\":\"access forbidden\"}", w.Body.String())
assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String())
// Login with invalid token
w = httptest.NewRecorder()
@ -37,7 +31,7 @@ func TestAPIKey(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
assert.Equal(t, "{\"message\":\"access forbidden\"}", w.Body.String())
assert.Equal(t, `{"message":"access forbidden"}`, w.Body.String())
// Login with valid token
w = httptest.NewRecorder()

View file

@ -36,6 +36,7 @@ import (
func getDBClient(t *testing.T) *database.Client {
t.Helper()
dbPath, err := os.CreateTemp("", "*sqlite")
require.NoError(t, err)
dbClient, err := database.NewClient(&csconfig.DatabaseCfg{
@ -72,8 +73,9 @@ func getAPIC(t *testing.T) *apic {
}
}
func absDiff(a int, b int) (c int) {
if c = a - b; c < 0 {
func absDiff(a int, b int) int {
c := a - b
if c < 0 {
return -1 * c
}
@ -185,6 +187,7 @@ func TestAPICFetchScenariosListFromDB(t *testing.T) {
func TestNewAPIC(t *testing.T) {
var testConfig *csconfig.OnlineApiClientCfg
setConfig := func() {
testConfig = &csconfig.OnlineApiClientCfg{
Credentials: &csconfig.ApiCredentialsCfg{
@ -199,6 +202,7 @@ func TestNewAPIC(t *testing.T) {
dbClient *database.Client
consoleConfig *csconfig.ConsoleConfig
}
tests := []struct {
name string
args args
@ -374,7 +378,6 @@ func TestAPICGetMetrics(t *testing.T) {
assert.Equal(t, tc.expectedMetric.Bouncers, foundMetrics.Bouncers)
assert.Equal(t, tc.expectedMetric.Machines, foundMetrics.Machines)
})
}
}
@ -403,6 +406,7 @@ func TestCreateAlertsForDecision(t *testing.T) {
type args struct {
decisions []*models.Decision
}
tests := []struct {
name string
args args
@ -489,6 +493,7 @@ func TestFillAlertsWithDecisions(t *testing.T) {
alerts []*models.Alert
decisions []*models.Decision
}
tests := []struct {
name string
args args
@ -544,26 +549,18 @@ func TestAPICWhitelists(t *testing.T) {
api := getAPIC(t)
//one whitelist on IP, one on CIDR
api.whitelists = &csconfig.CapiWhitelist{}
ipwl1 := "9.2.3.4"
ip := net.ParseIP(ipwl1)
api.whitelists.Ips = append(api.whitelists.Ips, ip)
ipwl1 = "7.2.3.4"
ip = net.ParseIP(ipwl1)
api.whitelists.Ips = append(api.whitelists.Ips, ip)
cidrwl1 := "13.2.3.0/24"
_, tnet, err := net.ParseCIDR(cidrwl1)
if err != nil {
t.Fatalf("unable to parse cidr : %s", err)
}
api.whitelists.Ips = append(api.whitelists.Ips, net.ParseIP("9.2.3.4"), net.ParseIP("7.2.3.4"))
_, tnet, err := net.ParseCIDR("13.2.3.0/24")
require.NoError(t, err)
api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet)
cidrwl1 = "11.2.3.0/24"
_, tnet, err = net.ParseCIDR(cidrwl1)
if err != nil {
t.Fatalf("unable to parse cidr : %s", err)
}
_, tnet, err = net.ParseCIDR("11.2.3.0/24")
require.NoError(t, err)
api.whitelists.Cidrs = append(api.whitelists.Cidrs, tnet)
api.dbClient.Ent.Decision.Create().
SetOrigin(types.CAPIOrigin).
SetType("ban").
@ -663,12 +660,15 @@ func TestAPICWhitelists(t *testing.T) {
},
),
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", httpmock.NewStringResponder(
200, "1.2.3.6",
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist2", httpmock.NewStringResponder(
200, "1.2.3.7",
))
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
@ -801,12 +801,15 @@ func TestAPICPullTop(t *testing.T) {
},
),
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", httpmock.NewStringResponder(
200, "1.2.3.6",
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist2", httpmock.NewStringResponder(
200, "1.2.3.7",
))
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
@ -828,7 +831,8 @@ func TestAPICPullTop(t *testing.T) {
alerts := api.dbClient.Ent.Alert.Query().AllX(context.Background())
validDecisions := api.dbClient.Ent.Decision.Query().Where(
decision.UntilGT(time.Now())).
AllX(context.Background())
AllX(context.Background(),
)
decisionScenarioFreq := make(map[string]int)
alertScenario := make(map[string]int)
@ -858,6 +862,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/api/decisions/stream", httpmock.NewBytesResponder(
200, jsonMarshalX(
modelscapi.GetDecisionsStreamResponse{
@ -887,10 +892,12 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
},
),
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(200, "1.2.3.4"), nil
})
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
@ -916,6 +923,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
assert.NotEqual(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(304, ""), nil
})
err = api.PullTop(false)
require.NoError(t, err)
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(blocklistConfigItemName)
@ -928,6 +936,7 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()
// create a decision about to expire. It should force fetch
alertInstance := api.dbClient.Ent.Alert.
Create().
@ -975,10 +984,12 @@ func TestAPICPullTopBLCacheForceCall(t *testing.T) {
},
),
))
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(304, ""), nil
})
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
@ -1005,6 +1016,7 @@ func TestAPICPullBlocklistCall(t *testing.T) {
assert.Equal(t, "", req.Header.Get("If-Modified-Since"))
return httpmock.NewStringResponse(200, "1.2.3.4"), nil
})
url, err := url.ParseRequestURI("http://api.crowdsec.net/")
require.NoError(t, err)
@ -1073,6 +1085,7 @@ func TestAPICPush(t *testing.T) {
Source: &models.Source{},
}
}
return alerts
}(),
},

View file

@ -307,6 +307,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
log.Errorf("capi push: %s", err)
return err
}
return nil
})
@ -315,6 +316,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
log.Errorf("capi pull: %s", err)
return err
}
return nil
})
@ -328,6 +330,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
log.Errorf("papi pull: %s", err)
return err
}
return nil
})
@ -336,6 +339,7 @@ func (s *APIServer) Run(apiReady chan bool) error {
log.Errorf("capi decisions sync: %s", err)
return err
}
return nil
})
} else {

View file

@ -13,11 +13,12 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/cstest"
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/version"
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
@ -63,13 +64,14 @@ func LoadTestConfig(t *testing.T) csconfig.Config {
ShareCustomScenarios: new(bool),
},
}
apiConfig := csconfig.APICfg{
Server: &apiServerConfig,
}
config.API = &apiConfig
if err := config.API.Server.LoadProfiles(); err != nil {
log.Fatalf("failed to load profiles: %s", err)
}
err := config.API.Server.LoadProfiles()
require.NoError(t, err)
return config
}
@ -106,110 +108,89 @@ func LoadTestConfigForwardedFor(t *testing.T) csconfig.Config {
Server: &apiServerConfig,
}
config.API = &apiConfig
if err := config.API.Server.LoadProfiles(); err != nil {
log.Fatalf("failed to load profiles: %s", err)
}
err := config.API.Server.LoadProfiles()
require.NoError(t, err)
return config
}
func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config, error) {
func NewAPIServer(t *testing.T) (*APIServer, csconfig.Config) {
config := LoadTestConfig(t)
os.Remove("./ent")
apiServer, err := NewServer(config.API.Server)
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
require.NoError(t, err)
log.Printf("Creating new API server")
gin.SetMode(gin.TestMode)
return apiServer, config, nil
return apiServer, config
}
func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config, error) {
apiServer, config, err := NewAPIServer(t)
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
err = apiServer.InitController()
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
func NewAPITest(t *testing.T) (*gin.Engine, csconfig.Config) {
apiServer, config := NewAPIServer(t)
err := apiServer.InitController()
require.NoError(t, err)
router, err := apiServer.Router()
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
require.NoError(t, err)
return router, config, nil
return router, config
}
func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config, error) {
func NewAPITestForwardedFor(t *testing.T) (*gin.Engine, csconfig.Config) {
config := LoadTestConfigForwardedFor(t)
os.Remove("./ent")
apiServer, err := NewServer(config.API.Server)
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
require.NoError(t, err)
err = apiServer.InitController()
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
require.NoError(t, err)
log.Printf("Creating new API server")
gin.SetMode(gin.TestMode)
router, err := apiServer.Router()
if err != nil {
return nil, config, fmt.Errorf("unable to run local API: %s", err)
}
require.NoError(t, err)
return router, config, nil
return router, config
}
func ValidateMachine(machineID string, config *csconfig.DatabaseCfg) error {
func ValidateMachine(t *testing.T, machineID string, config *csconfig.DatabaseCfg) {
dbClient, err := database.NewClient(config)
if err != nil {
return fmt.Errorf("unable to create new database client: %s", err)
}
require.NoError(t, err)
if err := dbClient.ValidateMachine(machineID); err != nil {
return fmt.Errorf("unable to validate machine: %s", err)
}
return nil
err = dbClient.ValidateMachine(machineID)
require.NoError(t, err)
}
func GetMachineIP(machineID string, config *csconfig.DatabaseCfg) (string, error) {
func GetMachineIP(t *testing.T, machineID string, config *csconfig.DatabaseCfg) string {
dbClient, err := database.NewClient(config)
if err != nil {
return "", fmt.Errorf("unable to create new database client: %s", err)
}
require.NoError(t, err)
machines, err := dbClient.ListMachines()
if err != nil {
return "", fmt.Errorf("Unable to list machines: %s", err)
}
require.NoError(t, err)
for _, machine := range machines {
if machine.MachineId == machineID {
return machine.IpAddress, nil
return machine.IpAddress
}
}
return "", nil
return ""
}
func GetAlertReaderFromFile(path string) *strings.Reader {
func GetAlertReaderFromFile(t *testing.T, path string) *strings.Reader {
alertContentBytes, err := os.ReadFile(path)
if err != nil {
log.Fatal(err)
}
require.NoError(t, err)
alerts := make([]*models.Alert, 0)
if err = json.Unmarshal(alertContentBytes, &alerts); err != nil {
log.Fatal(err)
}
err = json.Unmarshal(alertContentBytes, &alerts)
require.NoError(t, err)
for _, alert := range alerts {
*alert.StartAt = time.Now().UTC().Format(time.RFC3339)
@ -217,74 +198,57 @@ func GetAlertReaderFromFile(path string) *strings.Reader {
}
alertContent, err := json.Marshal(alerts)
if err != nil {
log.Fatal(err)
}
require.NoError(t, err)
return strings.NewReader(string(alertContent))
}
func readDecisionsGetResp(resp *httptest.ResponseRecorder) ([]*models.Decision, int, error) {
func readDecisionsGetResp(t *testing.T, resp *httptest.ResponseRecorder) ([]*models.Decision, int) {
var response []*models.Decision
if resp == nil {
return nil, 0, errors.New("response is nil")
}
err := json.Unmarshal(resp.Body.Bytes(), &response)
if err != nil {
return nil, resp.Code, err
}
require.NotNil(t, resp)
return response, resp.Code, nil
err := json.Unmarshal(resp.Body.Bytes(), &response)
require.NoError(t, err)
return response, resp.Code
}
func readDecisionsErrorResp(resp *httptest.ResponseRecorder) (map[string]string, int, error) {
func readDecisionsErrorResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string]string, int) {
var response map[string]string
if resp == nil {
return nil, 0, errors.New("response is nil")
}
err := json.Unmarshal(resp.Body.Bytes(), &response)
if err != nil {
return nil, resp.Code, err
}
require.NotNil(t, resp)
return response, resp.Code, nil
err := json.Unmarshal(resp.Body.Bytes(), &response)
require.NoError(t, err)
return response, resp.Code
}
func readDecisionsDeleteResp(resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int, error) {
func readDecisionsDeleteResp(t *testing.T, resp *httptest.ResponseRecorder) (*models.DeleteDecisionResponse, int) {
var response models.DeleteDecisionResponse
if resp == nil {
return nil, 0, errors.New("response is nil")
}
require.NotNil(t, resp)
err := json.Unmarshal(resp.Body.Bytes(), &response)
if err != nil {
return nil, resp.Code, err
}
require.NoError(t, err)
return &response, resp.Code, nil
return &response, resp.Code
}
func readDecisionsStreamResp(resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int, error) {
func readDecisionsStreamResp(t *testing.T, resp *httptest.ResponseRecorder) (map[string][]*models.Decision, int) {
response := make(map[string][]*models.Decision)
if resp == nil {
return nil, 0, errors.New("response is nil")
}
require.NotNil(t, resp)
err := json.Unmarshal(resp.Body.Bytes(), &response)
if err != nil {
return nil, resp.Code, err
}
require.NoError(t, err)
return response, resp.Code, nil
return response, resp.Code
}
func CreateTestMachine(router *gin.Engine) (string, error) {
func CreateTestMachine(t *testing.T, router *gin.Engine) string {
b, err := json.Marshal(MachineTest)
if err != nil {
return "", fmt.Errorf("unable to marshal MachineTest")
}
require.NoError(t, err)
body := string(b)
w := httptest.NewRecorder()
@ -292,26 +256,20 @@ func CreateTestMachine(router *gin.Engine) (string, error) {
req.Header.Set("User-Agent", UserAgent)
router.ServeHTTP(w, req)
return body, nil
return body
}
func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) {
func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string {
dbClient, err := database.NewClient(config)
if err != nil {
log.Fatalf("unable to create new database client: %s", err)
}
require.NoError(t, err)
apiKey, err := middlewares.GenerateAPIKey(keyLength)
if err != nil {
return "", fmt.Errorf("unable to generate api key: %s", err)
}
require.NoError(t, err)
_, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
if err != nil {
return "", fmt.Errorf("unable to create blocker: %s", err)
}
require.NoError(t, err)
return apiKey, nil
return apiKey
}
func TestWithWrongDBConfig(t *testing.T) {
@ -334,10 +292,7 @@ func TestWithWrongFlushConfig(t *testing.T) {
}
func TestUnknownPath(t *testing.T) {
router, _, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, _ := NewAPITest(t)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/test", nil)
@ -384,24 +339,17 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
LogDir: tempDir,
DbConfig: &dbconfig,
}
lvl := log.DebugLevel
expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
expectedLines := []string{"/test42"}
cfg.LogLevel = &lvl
cfg.LogLevel = ptr.Of(log.DebugLevel)
// Configure logging
if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil {
t.Fatal(err)
}
err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
require.NoError(t, err)
api, err := NewServer(&cfg)
if err != nil {
t.Fatalf("failed to create api : %s", err)
}
if api == nil {
t.Fatalf("failed to create api #2 is nbill")
}
require.NoError(t, err)
require.NotNil(t, api)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/test42", nil)
@ -413,14 +361,10 @@ func TestLoggingDebugToFileConfig(t *testing.T) {
//check file content
data, err := os.ReadFile(expectedFile)
if err != nil {
t.Fatalf("failed to read file : %s", err)
}
require.NoError(t, err)
for _, expectedStr := range expectedLines {
if !strings.Contains(string(data), expectedStr) {
t.Fatalf("expected %s in %s", expectedStr, string(data))
}
assert.Contains(t, string(data), expectedStr)
}
}
@ -446,35 +390,29 @@ func TestLoggingErrorToFileConfig(t *testing.T) {
LogDir: tempDir,
DbConfig: &dbconfig,
}
lvl := log.ErrorLevel
expectedFile := fmt.Sprintf("%s/crowdsec_api.log", tempDir)
cfg.LogLevel = &lvl
cfg.LogLevel = ptr.Of(log.ErrorLevel)
// Configure logging
if err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false); err != nil {
t.Fatal(err)
}
api, err := NewServer(&cfg)
if err != nil {
t.Fatalf("failed to create api : %s", err)
}
err := types.SetDefaultLoggerConfig(cfg.LogMedia, cfg.LogDir, *cfg.LogLevel, cfg.LogMaxSize, cfg.LogMaxFiles, cfg.LogMaxAge, cfg.CompressLogs, false)
require.NoError(t, err)
if api == nil {
t.Fatalf("failed to create api #2 is nbill")
}
api, err := NewServer(&cfg)
require.NoError(t, err)
require.NotNil(t, api)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodGet, "/test42", nil)
req.Header.Set("User-Agent", UserAgent)
api.router.ServeHTTP(w, req)
assert.Equal(t, 404, w.Code)
assert.Equal(t, http.StatusNotFound, w.Code)
//wait for the request to happen
time.Sleep(500 * time.Millisecond)
//check file content
x, err := os.ReadFile(expectedFile)
if err == nil && len(x) > 0 {
t.Fatalf("file should be empty, got '%s'", x)
if err == nil {
require.Empty(t, x)
}
os.Remove("./crowdsec.log")

View file

@ -55,6 +55,7 @@ func serveHealth() http.HandlerFunc {
// no caching required
health.WithDisabledCache(),
)
return health.NewHandler(checker)
}
@ -76,6 +77,7 @@ func (c *Controller) NewV1() error {
if err != nil {
return err
}
c.Router.GET("/health", gin.WrapF(serveHealth()))
c.Router.Use(v1.PrometheusMiddleware())
c.Router.HandleMethodNotAllowed = true
@ -104,7 +106,6 @@ func (c *Controller) NewV1() error {
jwtAuth.DELETE("/decisions", c.HandlerV1.DeleteDecisions)
jwtAuth.DELETE("/decisions/:decision_id", c.HandlerV1.DeleteDecisionById)
jwtAuth.GET("/heartbeat", c.HandlerV1.HeartBeat)
}
apiKeyAuth := groupV1.Group("")

View file

@ -22,7 +22,6 @@ import (
)
func FormatOneAlert(alert *ent.Alert) *models.Alert {
var outputAlert models.Alert
startAt := alert.StartedAt.String()
StopAt := alert.StoppedAt.String()
@ -31,7 +30,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert {
machineID = alert.Edges.Owner.MachineId
}
outputAlert = models.Alert{
outputAlert := models.Alert{
ID: int64(alert.ID),
MachineID: machineID,
CreatedAt: alert.CreatedAt.Format(time.RFC3339),
@ -58,23 +57,27 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert {
Longitude: alert.SourceLongitude,
},
}
for _, eventItem := range alert.Edges.Events {
var Metas models.Meta
timestamp := eventItem.Time.String()
if err := json.Unmarshal([]byte(eventItem.Serialized), &Metas); err != nil {
log.Errorf("unable to unmarshall events meta '%s' : %s", eventItem.Serialized, err)
}
outputAlert.Events = append(outputAlert.Events, &models.Event{
Timestamp: &timestamp,
Meta: Metas,
})
}
for _, metaItem := range alert.Edges.Metas {
outputAlert.Meta = append(outputAlert.Meta, &models.MetaItems0{
Key: metaItem.Key,
Value: metaItem.Value,
})
}
for _, decisionItem := range alert.Edges.Decisions {
duration := decisionItem.Until.Sub(time.Now().UTC()).String()
outputAlert.Decisions = append(outputAlert.Decisions, &models.Decision{
@ -88,6 +91,7 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert {
ID: int64(decisionItem.ID),
})
}
return &outputAlert
}
@ -97,6 +101,7 @@ func FormatAlerts(result []*ent.Alert) models.AddAlertsRequest {
for _, alertItem := range result {
data = append(data, FormatOneAlert(alertItem))
}
return data
}
@ -107,6 +112,7 @@ func (c *Controller) sendAlertToPluginChannel(alert *models.Alert, profileID uin
select {
case c.PluginChannel <- csplugin.ProfileAlert{ProfileID: profileID, Alert: alert}:
log.Debugf("alert sent to Plugin channel")
break RETRY
default:
log.Warningf("Cannot send alert to Plugin channel (try: %d)", try)
@ -133,7 +139,6 @@ func normalizeScope(scope string) string {
// CreateAlert writes the alerts received in the body to the database
func (c *Controller) CreateAlert(gctx *gin.Context) {
var input models.AddAlertsRequest
claims := jwt.ExtractClaims(gctx)
@ -144,13 +149,16 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return
}
if err := input.Validate(strfmt.Default); err != nil {
c.HandleDBErrors(gctx, err)
return
}
stopFlush := false
for _, alert := range input {
//normalize scope for alert.Source and decisions
// normalize scope for alert.Source and decisions
if alert.Source.Scope != nil {
*alert.Source.Scope = normalizeScope(*alert.Source.Scope)
}
@ -161,15 +169,16 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
}
alert.MachineID = machineID
//generate uuid here for alert
// generate uuid here for alert
alert.UUID = uuid.NewString()
//if coming from cscli, alert already has decisions
// if coming from cscli, alert already has decisions
if len(alert.Decisions) != 0 {
//alert already has a decision (cscli decisions add etc.), generate uuid here
for _, decision := range alert.Decisions {
decision.UUID = uuid.NewString()
}
for pIdx, profile := range c.Profiles {
_, matched, err := profile.EvaluateProfile(alert)
if err != nil {

View file

@ -4,7 +4,6 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
@ -16,23 +15,22 @@ func TestDeleteDecisionRange(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_minibulk.json")
lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json")
// delete by ip wrong
w := lapi.RecordResponse("DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD)
w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by range
w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String())
// delete by range : ensure it was already deleted
w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
}
@ -41,23 +39,23 @@ func TestDeleteDecisionFilter(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_minibulk.json")
lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json")
// delete by ip wrong
w := lapi.RecordResponse("DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD)
w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by ip good
w = lapi.RecordResponse("DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
// delete by scope/value
w = lapi.RecordResponse("DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
}
@ -66,17 +64,17 @@ func TestDeleteDecisionFilterByScenario(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_minibulk.json")
lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json")
// delete by wrong scenario
w := lapi.RecordResponse("DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD)
w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bff", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by scenario good
w = lapi.RecordResponse("DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions?scenario=crowdsecurity/ssh-bf", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String())
}
@ -85,14 +83,13 @@ func TestGetDecisionFilters(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_minibulk.json")
lapi.InsertAlertFromFile(t, "./tests/alert_minibulk.json")
// Get Decision
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY)
w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err := readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code := readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 2)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
@ -104,10 +101,9 @@ func TestGetDecisionFilters(t *testing.T) {
// Get Decision : type filter
w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, APIKEY)
w = lapi.RecordResponse(t, "GET", "/v1/decisions?type=ban", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code = readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 2)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
@ -122,10 +118,9 @@ func TestGetDecisionFilters(t *testing.T) {
// Get Decision : scope/value
w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY)
w = lapi.RecordResponse(t, "GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code = readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 1)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
@ -137,10 +132,9 @@ func TestGetDecisionFilters(t *testing.T) {
// Get Decision : ip filter
w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY)
w = lapi.RecordResponse(t, "GET", "/v1/decisions?ip=91.121.79.179", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code = readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 1)
assert.Equal(t, "crowdsecurity/ssh-bf", *decisions[0].Scenario)
@ -151,10 +145,9 @@ func TestGetDecisionFilters(t *testing.T) {
// assert.NotContains(t, w.Body.String(), `"id":2,"origin":"crowdsec","scenario":"crowdsecurity/ssh-bf","scope":"Ip","type":"ban","value":"91.121.79.178"`)
// Get decision : by range
w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY)
w = lapi.RecordResponse(t, "GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code = readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 2)
assert.Contains(t, []string{*decisions[0].Value, *decisions[1].Value}, "91.121.79.179")
@ -165,13 +158,12 @@ func TestGetDecision(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// Get Decision
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, APIKEY)
w := lapi.RecordResponse(t, "GET", "/v1/decisions", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
decisions, code, err := readDecisionsGetResp(w)
require.NoError(t, err)
decisions, code := readDecisionsGetResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions, 3)
/*decisions get doesn't perform deduplication*/
@ -188,7 +180,7 @@ func TestGetDecision(t *testing.T) {
assert.Equal(t, int64(3), decisions[2].ID)
// Get Decision with invalid filter. It should ignore this filter
w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, APIKEY)
w = lapi.RecordResponse(t, "GET", "/v1/decisions?test=test", emptyBody, APIKEY)
assert.Equal(t, 200, w.Code)
assert.Len(t, decisions, 3)
}
@ -197,49 +189,43 @@ func TestDeleteDecisionByID(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
//Have one alerts
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err := readDecisionsStreamResp(w)
require.NoError(t, err)
w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code := readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
// Delete alert with Invalid ID
w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/test", emptyBody, PASSWORD)
assert.Equal(t, 400, w.Code)
errResp, _, err := readDecisionsErrorResp(w)
require.NoError(t, err)
errResp, _ := readDecisionsErrorResp(t, w)
assert.Equal(t, "decision_id must be valid integer", errResp["message"])
// Delete alert with ID that not exist
w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/100", emptyBody, PASSWORD)
assert.Equal(t, 500, w.Code)
errResp, _, err = readDecisionsErrorResp(w)
require.NoError(t, err)
errResp, _ = readDecisionsErrorResp(t, w)
assert.Equal(t, "decision with id '100' doesn't exist: unable to delete", errResp["message"])
//Have one alerts
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w)
require.NoError(t, err)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
// Delete alert with valid ID
w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
resp, _, err := readDecisionsDeleteResp(w)
require.NoError(t, err)
resp, _ := readDecisionsDeleteResp(t, w)
assert.Equal(t, "1", resp.NbDeleted)
//Have one alert (because we delete an alert that has dup targets)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w)
require.NoError(t, err)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
@ -249,20 +235,18 @@ func TestDeleteDecision(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// Delete alert with Invalid filter
w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD)
w := lapi.RecordResponse(t, "DELETE", "/v1/decisions?test=test", emptyBody, PASSWORD)
assert.Equal(t, 500, w.Code)
errResp, _, err := readDecisionsErrorResp(w)
require.NoError(t, err)
errResp, _ := readDecisionsErrorResp(t, w)
assert.Equal(t, "'test' doesn't exist: invalid filter", errResp["message"])
// Delete all alert
w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
resp, _, err := readDecisionsDeleteResp(w)
require.NoError(t, err)
resp, _ := readDecisionsDeleteResp(t, w)
assert.Equal(t, "3", resp.NbDeleted)
}
@ -271,12 +255,11 @@ func TestStreamStartDecisionDedup(t *testing.T) {
lapi := SetupLAPITest(t)
// Create Valid Alert : 3 decisions for 127.0.0.1, longest has id=3
lapi.InsertAlertFromFile("./tests/alert_sample.json")
lapi.InsertAlertFromFile(t, "./tests/alert_sample.json")
// Get Stream, we only get one decision (the longest one)
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err := readDecisionsStreamResp(w)
require.NoError(t, err)
w := lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code := readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
@ -285,13 +268,12 @@ func TestStreamStartDecisionDedup(t *testing.T) {
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
// id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip
w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/3", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
// Get Stream, we only get one decision (the longest one, id=2)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w)
require.NoError(t, err)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
@ -300,13 +282,12 @@ func TestStreamStartDecisionDedup(t *testing.T) {
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
// We delete another decision, yet don't receive it in stream, since there's another decision on same IP
w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/2", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
// And get the remaining decision (1)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w)
require.NoError(t, err)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Empty(t, decisions["deleted"])
assert.Len(t, decisions["new"], 1)
@ -315,13 +296,12 @@ func TestStreamStartDecisionDedup(t *testing.T) {
assert.Equal(t, "127.0.0.1", *decisions["new"][0].Value)
// We delete the last decision, we receive the delete order
w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
w = lapi.RecordResponse(t, "DELETE", "/v1/decisions/1", emptyBody, PASSWORD)
assert.Equal(t, 200, w.Code)
//and now we only get a deleted decision
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code, err = readDecisionsStreamResp(w)
require.NoError(t, err)
w = lapi.RecordResponse(t, "GET", "/v1/decisions/stream?startup=true", emptyBody, APIKEY)
decisions, code = readDecisionsStreamResp(t, w)
assert.Equal(t, 200, code)
assert.Len(t, decisions["deleted"], 1)
assert.Equal(t, int64(1), decisions["deleted"][0].ID)

View file

@ -10,9 +10,9 @@ import (
func TestHeartBeat(t *testing.T) {
lapi := SetupLAPITest(t)
w := lapi.RecordResponse(http.MethodGet, "/v1/heartbeat", emptyBody, "password")
w := lapi.RecordResponse(t, http.MethodGet, "/v1/heartbeat", emptyBody, "password")
assert.Equal(t, 200, w.Code)
w = lapi.RecordResponse("POST", "/v1/heartbeat", emptyBody, "password")
w = lapi.RecordResponse(t, "POST", "/v1/heartbeat", emptyBody, "password")
assert.Equal(t, 405, w.Code)
}

View file

@ -6,20 +6,13 @@ import (
"strings"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func TestLogin(t *testing.T) {
router, config, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, config := NewAPITest(t)
body, err := CreateTestMachine(router)
if err != nil {
log.Fatalln(err)
}
body := CreateTestMachine(t, router)
// Login with machine not validated yet
w := httptest.NewRecorder()
@ -28,16 +21,16 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, "{\"code\":401,\"message\":\"machine test not validated\"}", w.Body.String())
assert.Equal(t, `{"code":401,"message":"machine test not validated"}`, w.Body.String())
// Login with machine not exist
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test1\", \"password\": \"test1\"}"))
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1", "password": "test1"}`))
req.Header.Add("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, "{\"code\":401,\"message\":\"ent: machine not found\"}", w.Body.String())
assert.Equal(t, `{"code":401,"message":"ent: machine not found"}`, w.Body.String())
// Login with invalid body
w = httptest.NewRecorder()
@ -46,31 +39,28 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, "{\"code\":401,\"message\":\"missing: invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String())
assert.Equal(t, `{"code":401,"message":"missing: invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
// Login with invalid format
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test1\"}"))
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test1"}`))
req.Header.Add("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, "{\"code\":401,\"message\":\"validation failure list:\\npassword in body is required\"}", w.Body.String())
assert.Equal(t, `{"code":401,"message":"validation failure list:\npassword in body is required"}`, w.Body.String())
//Validate machine
err = ValidateMachine("test", config.API.Server.DbConfig)
if err != nil {
log.Fatalln(err)
}
ValidateMachine(t, "test", config.API.Server.DbConfig)
// Login with invalid password
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test\", \"password\": \"test1\"}"))
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test1"}`))
req.Header.Add("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 401, w.Code)
assert.Equal(t, "{\"code\":401,\"message\":\"incorrect Username or Password\"}", w.Body.String())
assert.Equal(t, `{"code":401,"message":"incorrect Username or Password"}`, w.Body.String())
// Login with valid machine
w = httptest.NewRecorder()
@ -79,16 +69,16 @@ func TestLogin(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "\"token\"")
assert.Contains(t, w.Body.String(), "\"expire\"")
assert.Contains(t, w.Body.String(), `"token"`)
assert.Contains(t, w.Body.String(), `"expire"`)
// Login with valid machine + scenarios
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader("{\"machine_id\": \"test\", \"password\": \"test\", \"scenarios\": [\"crowdsecurity/test\", \"crowdsecurity/test2\"]}"))
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers/login", strings.NewReader(`{"machine_id": "test", "password": "test", "scenarios": ["crowdsecurity/test", "crowdsecurity/test2"]}`))
req.Header.Add("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "\"token\"")
assert.Contains(t, w.Body.String(), "\"expire\"")
assert.Contains(t, w.Body.String(), `"token"`)
assert.Contains(t, w.Body.String(), `"expire"`)
}

View file

@ -7,15 +7,12 @@ import (
"strings"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateMachine(t *testing.T) {
router, _, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, _ := NewAPITest(t)
// Create machine with invalid format
w := httptest.NewRecorder()
@ -24,22 +21,21 @@ func TestCreateMachine(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 400, w.Code)
assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String())
assert.Equal(t, `{"message":"invalid character 'e' in literal true (expecting 'r')"}`, w.Body.String())
// Create machine with invalid input
w = httptest.NewRecorder()
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader("{\"test\": \"test\"}"))
req, _ = http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(`{"test": "test"}`))
req.Header.Add("User-Agent", UserAgent)
router.ServeHTTP(w, req)
assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"validation failure list:\\nmachine_id in body is required\\npassword in body is required\"}", w.Body.String())
assert.Equal(t, `{"message":"validation failure list:\nmachine_id in body is required\npassword in body is required"}`, w.Body.String())
// Create machine
b, err := json.Marshal(MachineTest)
if err != nil {
log.Fatal("unable to marshal MachineTest")
}
require.NoError(t, err)
body := string(b)
w = httptest.NewRecorder()
@ -52,16 +48,12 @@ func TestCreateMachine(t *testing.T) {
}
func TestCreateMachineWithForwardedFor(t *testing.T) {
router, config, err := NewAPITestForwardedFor(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, config := NewAPITestForwardedFor(t)
router.TrustedPlatform = "X-Real-IP"
// Create machine
b, err := json.Marshal(MachineTest)
if err != nil {
log.Fatal("unable to marshal MachineTest")
}
require.NoError(t, err)
body := string(b)
w := httptest.NewRecorder()
@ -73,25 +65,18 @@ func TestCreateMachineWithForwardedFor(t *testing.T) {
assert.Equal(t, 201, w.Code)
assert.Equal(t, "", w.Body.String())
ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig)
if err != nil {
log.Fatalf("Could not get machine IP : %s", err)
}
ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig)
assert.Equal(t, "1.1.1.1", ip)
}
func TestCreateMachineWithForwardedForNoConfig(t *testing.T) {
router, config, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, config := NewAPITest(t)
// Create machine
b, err := json.Marshal(MachineTest)
if err != nil {
log.Fatal("unable to marshal MachineTest")
}
require.NoError(t, err)
body := string(b)
w := httptest.NewRecorder()
@ -103,26 +88,20 @@ func TestCreateMachineWithForwardedForNoConfig(t *testing.T) {
assert.Equal(t, 201, w.Code)
assert.Equal(t, "", w.Body.String())
ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig)
if err != nil {
log.Fatalf("Could not get machine IP : %s", err)
}
ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig)
//For some reason, the IP is empty when running tests
//if no forwarded-for headers are present
assert.Equal(t, "", ip)
}
func TestCreateMachineWithoutForwardedFor(t *testing.T) {
router, config, err := NewAPITestForwardedFor(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, config := NewAPITestForwardedFor(t)
// Create machine
b, err := json.Marshal(MachineTest)
if err != nil {
log.Fatal("unable to marshal MachineTest")
}
require.NoError(t, err)
body := string(b)
w := httptest.NewRecorder()
@ -133,25 +112,17 @@ func TestCreateMachineWithoutForwardedFor(t *testing.T) {
assert.Equal(t, 201, w.Code)
assert.Equal(t, "", w.Body.String())
ip, err := GetMachineIP(*MachineTest.MachineID, config.API.Server.DbConfig)
if err != nil {
log.Fatalf("Could not get machine IP : %s", err)
}
ip := GetMachineIP(t, *MachineTest.MachineID, config.API.Server.DbConfig)
//For some reason, the IP is empty when running tests
//if no forwarded-for headers are present
assert.Equal(t, "", ip)
}
func TestCreateMachineAlreadyExist(t *testing.T) {
router, _, err := NewAPITest(t)
if err != nil {
log.Fatalf("unable to run local API: %s", err)
}
router, _ := NewAPITest(t)
body, err := CreateTestMachine(router)
if err != nil {
log.Fatalln(err)
}
body := CreateTestMachine(t, router)
w := httptest.NewRecorder()
req, _ := http.NewRequest(http.MethodPost, "/v1/watchers", strings.NewReader(body))
@ -164,5 +135,5 @@ func TestCreateMachineAlreadyExist(t *testing.T) {
router.ServeHTTP(w, req)
assert.Equal(t, 403, w.Code)
assert.Equal(t, "{\"message\":\"user 'test': user already exist\"}", w.Body.String())
assert.Equal(t, `{"message":"user 'test': user already exist"}`, w.Body.String())
}

View file

@ -28,6 +28,7 @@ rules:
type match struct {
Type string `yaml:"type"`
Value string `yaml:"value"`
Not bool `yaml:"not,omitempty"`
}
type CustomRule struct {
@ -40,7 +41,8 @@ type CustomRule struct {
Transform []string `yaml:"transform"` //t:lowercase, t:uppercase, etc
And []CustomRule `yaml:"and,omitempty"`
Or []CustomRule `yaml:"or,omitempty"`
BodyType string `yaml:"body_type,omitempty"`
BodyType string `yaml:"body_type,omitempty"`
}
func (v *CustomRule) Convert(ruleType string, appsecRuleName string) (string, []uint32, error) {

View file

@ -8,6 +8,16 @@ func TestVPatchRuleString(t *testing.T) {
rule CustomRule
expected string
}{
{
name: "Collection count",
rule: CustomRule{
Zones: []string{"ARGS"},
Variables: []string{"foo"},
Match: match{Type: "eq", Value: "1"},
Transform: []string{"count"},
},
expected: `SecRule &ARGS_GET:foo "@eq 1" "id:853070236,phase:2,deny,log,msg:'Collection count',tag:'crowdsec-Collection count'"`,
},
{
name: "Base Rule",
rule: CustomRule{
@ -18,6 +28,32 @@ func TestVPatchRuleString(t *testing.T) {
},
expected: `SecRule ARGS_GET:foo "@rx [^a-zA-Z]" "id:2203944045,phase:2,deny,log,msg:'Base Rule',tag:'crowdsec-Base Rule',t:lowercase"`,
},
{
name: "One zone, multi var",
rule: CustomRule{
Zones: []string{"ARGS"},
Variables: []string{"foo", "bar"},
Match: match{Type: "regex", Value: "[^a-zA-Z]"},
Transform: []string{"lowercase"},
},
expected: `SecRule ARGS_GET:foo|ARGS_GET:bar "@rx [^a-zA-Z]" "id:385719930,phase:2,deny,log,msg:'One zone, multi var',tag:'crowdsec-One zone, multi var',t:lowercase"`,
},
{
name: "Base Rule #2",
rule: CustomRule{
Zones: []string{"METHOD"},
Match: match{Type: "startsWith", Value: "toto"},
},
expected: `SecRule REQUEST_METHOD "@beginsWith toto" "id:2759779019,phase:2,deny,log,msg:'Base Rule #2',tag:'crowdsec-Base Rule #2'"`,
},
{
name: "Base Negative Rule",
rule: CustomRule{
Zones: []string{"METHOD"},
Match: match{Type: "startsWith", Value: "toto", Not: true},
},
expected: `SecRule REQUEST_METHOD "!@beginsWith toto" "id:3966251995,phase:2,deny,log,msg:'Base Negative Rule',tag:'crowdsec-Base Negative Rule'"`,
},
{
name: "Multiple Zones",
rule: CustomRule{
@ -28,6 +64,25 @@ func TestVPatchRuleString(t *testing.T) {
},
expected: `SecRule ARGS_GET:foo|ARGS_POST:foo "@rx [^a-zA-Z]" "id:3387135861,phase:2,deny,log,msg:'Multiple Zones',tag:'crowdsec-Multiple Zones',t:lowercase"`,
},
{
name: "Multiple Zones Multi Var",
rule: CustomRule{
Zones: []string{"ARGS", "BODY_ARGS"},
Variables: []string{"foo", "bar"},
Match: match{Type: "regex", Value: "[^a-zA-Z]"},
Transform: []string{"lowercase"},
},
expected: `SecRule ARGS_GET:foo|ARGS_GET:bar|ARGS_POST:foo|ARGS_POST:bar "@rx [^a-zA-Z]" "id:1119773585,phase:2,deny,log,msg:'Multiple Zones Multi Var',tag:'crowdsec-Multiple Zones Multi Var',t:lowercase"`,
},
{
name: "Multiple Zones No Vars",
rule: CustomRule{
Zones: []string{"ARGS", "BODY_ARGS"},
Match: match{Type: "regex", Value: "[^a-zA-Z]"},
Transform: []string{"lowercase"},
},
expected: `SecRule ARGS_GET|ARGS_POST "@rx [^a-zA-Z]" "id:2020110336,phase:2,deny,log,msg:'Multiple Zones No Vars',tag:'crowdsec-Multiple Zones No Vars',t:lowercase"`,
},
{
name: "Basic AND",
rule: CustomRule{

View file

@ -44,6 +44,7 @@ var matchMap map[string]string = map[string]string{
"lt": "@lt",
"gte": "@ge",
"lte": "@le",
"eq": "@eq",
}
var bodyTypeMatch map[string]string = map[string]string{
@ -121,7 +122,20 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an
return ret, nil
}
zone_prefix := ""
variable_prefix := ""
if rule.Transform != nil {
for tidx, transform := range rule.Transform {
if transform == "count" {
zone_prefix = "&"
rule.Transform[tidx] = ""
}
}
}
for idx, zone := range rule.Zones {
if idx > 0 {
r.WriteByte('|')
}
mappedZone, ok := zonesMap[zone]
if !ok {
return nil, fmt.Errorf("unknown zone '%s'", zone)
@ -130,10 +144,10 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an
r.WriteString(mappedZone)
} else {
for j, variable := range rule.Variables {
if idx > 0 || j > 0 {
if j > 0 {
r.WriteByte('|')
}
r.WriteString(fmt.Sprintf("%s:%s", mappedZone, variable))
r.WriteString(fmt.Sprintf("%s%s:%s%s", zone_prefix, mappedZone, variable_prefix, variable))
}
}
}
@ -141,7 +155,11 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an
if rule.Match.Type != "" {
if match, ok := matchMap[rule.Match.Type]; ok {
r.WriteString(fmt.Sprintf(`"%s %s"`, match, rule.Match.Value))
prefix := ""
if rule.Match.Not {
prefix = "!"
}
r.WriteString(fmt.Sprintf(`"%s%s %s"`, prefix, match, rule.Match.Value))
} else {
return nil, fmt.Errorf("unknown match type '%s'", rule.Match.Type)
}
@ -152,6 +170,9 @@ func (m *ModsecurityRule) buildRules(rule *CustomRule, appsecRuleName string, an
if rule.Transform != nil {
for _, transform := range rule.Transform {
if transform == "" {
continue
}
r.WriteByte(',')
if mappedTransform, ok := transformMap[transform]; ok {
r.WriteString(mappedTransform)

View file

@ -11,6 +11,7 @@ import (
"regexp"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)
@ -267,7 +268,7 @@ func (r *ReqDumpFilter) ToJSON() error {
}
// Generate a ParsedRequest from a http.Request. ParsedRequest can be consumed by the App security Engine
func NewParsedRequestFromRequest(r *http.Request) (ParsedRequest, error) {
func NewParsedRequestFromRequest(r *http.Request, logger *logrus.Entry) (ParsedRequest, error) {
var err error
contentLength := r.ContentLength
if contentLength < 0 {
@ -282,26 +283,23 @@ func NewParsedRequestFromRequest(r *http.Request) (ParsedRequest, error) {
}
}
// the real source of the request is set in 'x-client-ip'
clientIP := r.Header.Get(IPHeaderName)
if clientIP == "" {
return ParsedRequest{}, fmt.Errorf("missing '%s' header", IPHeaderName)
}
// the real target Host of the request is set in 'x-client-host'
clientHost := r.Header.Get(HostHeaderName)
if clientHost == "" {
return ParsedRequest{}, fmt.Errorf("missing '%s' header", HostHeaderName)
}
// the real URI of the request is set in 'x-client-uri'
clientURI := r.Header.Get(URIHeaderName)
if clientURI == "" {
return ParsedRequest{}, fmt.Errorf("missing '%s' header", URIHeaderName)
}
// the real VERB of the request is set in 'x-client-uri'
clientMethod := r.Header.Get(VerbHeaderName)
if clientMethod == "" {
return ParsedRequest{}, fmt.Errorf("missing '%s' header", VerbHeaderName)
}
clientHost := r.Header.Get(HostHeaderName)
if clientHost == "" { //this might be empty
logger.Debugf("missing '%s' header", HostHeaderName)
}
// delete those headers before coraza process the request
delete(r.Header, IPHeaderName)

View file

@ -61,36 +61,51 @@ func (a *CTICfg) Load() error {
if a.Key == nil {
*a.Enabled = false
}
if a.Key != nil && *a.Key == "" {
return fmt.Errorf("empty cti key")
}
if a.Enabled == nil {
a.Enabled = new(bool)
*a.Enabled = true
}
if a.CacheTimeout == nil {
a.CacheTimeout = new(time.Duration)
*a.CacheTimeout = 10 * time.Minute
}
if a.CacheSize == nil {
a.CacheSize = new(int)
*a.CacheSize = 100
}
return nil
}
func (o *OnlineApiClientCfg) Load() error {
o.Credentials = new(ApiCredentialsCfg)
fcontent, err := os.ReadFile(o.CredentialsFilePath)
if err != nil {
return fmt.Errorf("failed to read api server credentials configuration file '%s': %w", o.CredentialsFilePath, err)
return err
}
err = yaml.UnmarshalStrict(fcontent, o.Credentials)
if err != nil {
return fmt.Errorf("failed unmarshaling api server credentials configuration file '%s': %w", o.CredentialsFilePath, err)
}
if o.Credentials.Login == "" || o.Credentials.Password == "" || o.Credentials.URL == "" {
log.Warningf("can't load CAPI credentials from '%s' (missing field)", o.CredentialsFilePath)
switch {
case o.Credentials.Login == "":
log.Warningf("can't load CAPI credentials from '%s' (missing login field)", o.CredentialsFilePath)
o.Credentials = nil
case o.Credentials.Password == "":
log.Warningf("can't load CAPI credentials from '%s' (missing password field)", o.CredentialsFilePath)
o.Credentials = nil
case o.Credentials.URL == "":
log.Warningf("can't load CAPI credentials from '%s' (missing url field)", o.CredentialsFilePath)
o.Credentials = nil
}
@ -99,14 +114,17 @@ func (o *OnlineApiClientCfg) Load() error {
func (l *LocalApiClientCfg) Load() error {
patcher := yamlpatch.NewPatcher(l.CredentialsFilePath, ".local")
fcontent, err := patcher.MergedPatchContent()
if err != nil {
return err
}
err = yaml.UnmarshalStrict(fcontent, &l.Credentials)
if err != nil {
return fmt.Errorf("failed unmarshaling api client credential configuration file '%s': %w", l.CredentialsFilePath, err)
}
if l.Credentials == nil || l.Credentials.URL == "" {
return fmt.Errorf("no credentials or URL found in api client configuration '%s'", l.CredentialsFilePath)
}
@ -137,9 +155,11 @@ func (l *LocalApiClientCfg) Load() error {
if err != nil {
log.Warningf("Error loading system CA certificates: %s", err)
}
if caCertPool == nil {
caCertPool = x509.NewCertPool()
}
caCertPool.AppendCertsFromPEM(caCert)
apiclient.CaCertPool = caCertPool
}
@ -160,12 +180,15 @@ func (lapiCfg *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) {
trustedIPs := make([]net.IPNet, 0)
for _, ip := range lapiCfg.TrustedIPs {
cidr := toValidCIDR(ip)
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
return nil, err
}
trustedIPs = append(trustedIPs, *ipNet)
}
return trustedIPs, nil
}
@ -177,6 +200,7 @@ func toValidCIDR(ip string) string {
if strings.Contains(ip, ":") {
return ip + "/128"
}
return ip + "/32"
}
@ -327,24 +351,30 @@ func parseCapiWhitelists(fd io.Reader) (*CapiWhitelist, error) {
if errors.Is(err, io.EOF) {
return nil, fmt.Errorf("empty file")
}
return nil, err
}
ret := &CapiWhitelist{
Ips: make([]net.IP, len(fromCfg.Ips)),
Cidrs: make([]*net.IPNet, len(fromCfg.Cidrs)),
}
for idx, v := range fromCfg.Ips {
ip := net.ParseIP(v)
if ip == nil {
return nil, fmt.Errorf("invalid IP address: %s", v)
}
ret.Ips[idx] = ip
}
for idx, v := range fromCfg.Cidrs {
_, tnet, err := net.ParseCIDR(v)
if err != nil {
return nil, err
}
ret.Cidrs[idx] = tnet
}
@ -356,10 +386,6 @@ func (s *LocalApiServerCfg) LoadCapiWhitelists() error {
return nil
}
if _, err := os.Stat(s.CapiWhitelistsPath); os.IsNotExist(err) {
return fmt.Errorf("capi whitelist file '%s' does not exist", s.CapiWhitelistsPath)
}
fd, err := os.Open(s.CapiWhitelistsPath)
if err != nil {
return fmt.Errorf("while opening capi whitelist file: %s", err)

View file

@ -116,7 +116,7 @@ func TestLoadOnlineApiClientCfg(t *testing.T) {
CredentialsFilePath: "./testdata/nonexist_online-api-secrets.yaml",
},
expected: &ApiCredentialsCfg{},
expectedErr: "failed to read api server credentials",
expectedErr: "open ./testdata/nonexist_online-api-secrets.yaml: " + cstest.FileNotFoundMessage,
},
}
@ -254,6 +254,7 @@ func TestLoadAPIServer(t *testing.T) {
func mustParseCIDRNet(t *testing.T, s string) *net.IPNet {
_, ipNet, err := net.ParseCIDR(s)
require.NoError(t, err)
return ipNet
}

View file

@ -149,7 +149,7 @@ func (s *PluginSuite) TestBrokerNoThreshold() {
t := s.T()
pb, err := s.InitBroker(nil)
assert.NoError(t, err)
require.NoError(t, err)
tomb := tomb.Tomb{}
go pb.Run(&tomb)
@ -182,7 +182,7 @@ func (s *PluginSuite) TestBrokerNoThreshold() {
err = json.Unmarshal(content, &alerts)
log.Printf("content-> %s", content)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, alerts, 1)
}
@ -199,7 +199,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
s.writeconfig(cfg)
pb, err := s.InitBroker(nil)
assert.NoError(t, err)
require.NoError(t, err)
tomb := tomb.Tomb{}
go pb.Run(&tomb)
@ -215,11 +215,11 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_TimeFirst() {
time.Sleep(1 * time.Second)
// after 1 seconds, we should have data
content, err := os.ReadFile("./out")
assert.NoError(t, err)
require.NoError(t, err)
var alerts []models.Alert
err = json.Unmarshal(content, &alerts)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, alerts, 3)
}
@ -235,7 +235,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
s.writeconfig(cfg)
pb, err := s.InitBroker(nil)
assert.NoError(t, err)
require.NoError(t, err)
tomb := tomb.Tomb{}
go pb.Run(&tomb)
@ -259,7 +259,7 @@ func (s *PluginSuite) TestBrokerRunGroupAndTimeThreshold_CountFirst() {
var alerts []models.Alert
err = json.Unmarshal(content, &alerts)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, alerts, 4)
}
@ -275,7 +275,7 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
s.writeconfig(cfg)
pb, err := s.InitBroker(nil)
assert.NoError(t, err)
require.NoError(t, err)
tomb := tomb.Tomb{}
go pb.Run(&tomb)
@ -306,11 +306,11 @@ func (s *PluginSuite) TestBrokerRunGroupThreshold() {
// two notifications, one with 4 alerts, one with 2 alerts
err = decoder.Decode(&alerts)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, alerts, 4)
err = decoder.Decode(&alerts)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, alerts, 2)
err = decoder.Decode(&alerts)
@ -328,7 +328,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() {
s.writeconfig(cfg)
pb, err := s.InitBroker(nil)
assert.NoError(t, err)
require.NoError(t, err)
tomb := tomb.Tomb{}
go pb.Run(&tomb)
@ -348,7 +348,7 @@ func (s *PluginSuite) TestBrokerRunTimeThreshold() {
var alerts []models.Alert
err = json.Unmarshal(content, &alerts)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, alerts, 1)
}
@ -358,7 +358,7 @@ func (s *PluginSuite) TestBrokerRunSimple() {
t := s.T()
pb, err := s.InitBroker(nil)
assert.NoError(t, err)
require.NoError(t, err)
tomb := tomb.Tomb{}
go pb.Run(&tomb)
@ -382,11 +382,11 @@ func (s *PluginSuite) TestBrokerRunSimple() {
// two notifications, one alert each
err = decoder.Decode(&alerts)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, alerts, 1)
err = decoder.Decode(&alerts)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, alerts, 1)
err = decoder.Decode(&alerts)

View file

@ -70,7 +70,7 @@ func (s *PluginSuite) TestBrokerRun() {
t := s.T()
pb, err := s.InitBroker(nil)
assert.NoError(t, err)
require.NoError(t, err)
tomb := tomb.Tomb{}
go pb.Run(&tomb)
@ -94,11 +94,11 @@ func (s *PluginSuite) TestBrokerRun() {
// two notifications, one alert each
err = decoder.Decode(&alerts)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, alerts, 1)
err = decoder.Decode(&alerts)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, alerts, 1)
err = decoder.Decode(&alerts)

View file

@ -6,7 +6,6 @@ import (
"github.com/antonmedv/expr"
"github.com/antonmedv/expr/vm"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
@ -22,19 +21,23 @@ type Runtime struct {
Logger *log.Entry `json:"-" yaml:"-"`
}
var defaultDuration = "4h"
const defaultDuration = "4h"
func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) {
var err error
profilesRuntime := make([]*Runtime, 0)
for _, profile := range profilesCfg {
var runtimeFilter, runtimeDurationExpr *vm.Program
runtime := &Runtime{}
xlog := log.New()
if err := types.ConfigureLogger(xlog); err != nil {
log.Fatalf("While creating profiles-specific logger : %s", err)
}
xlog.SetLevel(log.InfoLevel)
runtime.Logger = xlog.WithFields(log.Fields{
"type": "profile",
@ -43,17 +46,20 @@ func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) {
runtime.RuntimeFilters = make([]*vm.Program, len(profile.Filters))
runtime.Cfg = profile
if runtime.Cfg.OnSuccess != "" && runtime.Cfg.OnSuccess != "continue" && runtime.Cfg.OnSuccess != "break" {
return []*Runtime{}, fmt.Errorf("invalid 'on_success' for '%s': %s", profile.Name, runtime.Cfg.OnSuccess)
}
if runtime.Cfg.OnFailure != "" && runtime.Cfg.OnFailure != "continue" && runtime.Cfg.OnFailure != "break" && runtime.Cfg.OnFailure != "apply" {
return []*Runtime{}, fmt.Errorf("invalid 'on_failure' for '%s' : %s", profile.Name, runtime.Cfg.OnFailure)
}
for fIdx, filter := range profile.Filters {
if runtime.Cfg.OnSuccess != "" && runtime.Cfg.OnSuccess != "continue" && runtime.Cfg.OnSuccess != "break" {
return nil, fmt.Errorf("invalid 'on_success' for '%s': %s", profile.Name, runtime.Cfg.OnSuccess)
}
if runtime.Cfg.OnFailure != "" && runtime.Cfg.OnFailure != "continue" && runtime.Cfg.OnFailure != "break" && runtime.Cfg.OnFailure != "apply" {
return nil, fmt.Errorf("invalid 'on_failure' for '%s' : %s", profile.Name, runtime.Cfg.OnFailure)
}
for fIdx, filter := range profile.Filters {
if runtimeFilter, err = expr.Compile(filter, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil {
return []*Runtime{}, errors.Wrapf(err, "error compiling filter of '%s'", profile.Name)
return nil, fmt.Errorf("error compiling filter of '%s': %w", profile.Name, err)
}
runtime.RuntimeFilters[fIdx] = runtimeFilter
if profile.Debug != nil && *profile.Debug {
runtime.Logger.Logger.SetLevel(log.DebugLevel)
@ -62,8 +68,9 @@ func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) {
if profile.DurationExpr != "" {
if runtimeDurationExpr, err = expr.Compile(profile.DurationExpr, exprhelpers.GetExprOptions(map[string]interface{}{"Alert": &models.Alert{}})...); err != nil {
return []*Runtime{}, errors.Wrapf(err, "error compiling duration_expr of %s", profile.Name)
return nil, fmt.Errorf("error compiling duration_expr of %s: %w", profile.Name, err)
}
runtime.RuntimeDurationExpr = runtimeDurationExpr
}
@ -76,14 +83,16 @@ func NewProfile(profilesCfg []*csconfig.ProfileCfg) ([]*Runtime, error) {
runtime.Logger.Warningf("No duration specified for %s, using default duration %s", profile.Name, defaultDuration)
duration = defaultDuration
}
if _, err := time.ParseDuration(duration); err != nil {
return []*Runtime{}, errors.Wrapf(err, "error parsing duration '%s' of %s", duration, profile.Name)
return nil, fmt.Errorf("error parsing duration '%s' of %s: %w", duration, profile.Name, err)
}
}
}
profilesRuntime = append(profilesRuntime, runtime)
}
return profilesRuntime, nil
}
@ -110,30 +119,29 @@ func (Profile *Runtime) GenerateDecisionFromProfile(Alert *models.Alert) ([]*mod
*decision.Scope = *Alert.Source.Scope
}
/*some fields are populated from the reference object : duration, scope, type*/
decision.Duration = new(string)
if refDecision.Duration != nil {
*decision.Duration = *refDecision.Duration
}
if Profile.Cfg.DurationExpr != "" && Profile.RuntimeDurationExpr != nil {
profileDebug := false
if Profile.Cfg.Debug != nil && *Profile.Cfg.Debug {
profileDebug = true
}
duration, err := exprhelpers.Run(Profile.RuntimeDurationExpr, map[string]interface{}{"Alert": Alert}, Profile.Logger, profileDebug)
if err != nil {
Profile.Logger.Warningf("Failed to run duration_expr : %v", err)
*decision.Duration = *refDecision.Duration
} else {
durationStr := fmt.Sprint(duration)
if _, err := time.ParseDuration(durationStr); err != nil {
Profile.Logger.Warningf("Failed to parse expr duration result '%s'", duration)
*decision.Duration = *refDecision.Duration
} else {
*decision.Duration = durationStr
}
}
} else {
if refDecision.Duration == nil {
*decision.Duration = defaultDuration
}
*decision.Duration = *refDecision.Duration
}
decision.Type = new(string)
@ -144,13 +152,16 @@ func (Profile *Runtime) GenerateDecisionFromProfile(Alert *models.Alert) ([]*mod
*decision.Value = *Alert.Source.Value
decision.Origin = new(string)
*decision.Origin = types.CrowdSecOrigin
if refDecision.Origin != nil {
*decision.Origin = fmt.Sprintf("%s/%s", *decision.Origin, *refDecision.Origin)
}
decision.Scenario = new(string)
*decision.Scenario = *Alert.Scenario
decisions = append(decisions, &decision)
}
return decisions, nil
}
@ -159,16 +170,19 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision
var decisions []*models.Decision
matched := false
for eIdx, expression := range Profile.RuntimeFilters {
debugProfile := false
if Profile.Cfg.Debug != nil && *Profile.Cfg.Debug {
debugProfile = true
}
output, err := exprhelpers.Run(expression, map[string]interface{}{"Alert": Alert}, Profile.Logger, debugProfile)
if err != nil {
Profile.Logger.Warningf("failed to run profile expr for %s : %v", Profile.Cfg.Name, err)
return nil, matched, errors.Wrapf(err, "while running expression %s", Profile.Cfg.Filters[eIdx])
Profile.Logger.Warningf("failed to run profile expr for %s: %v", Profile.Cfg.Name, err)
return nil, matched, fmt.Errorf("while running expression %s: %w", Profile.Cfg.Filters[eIdx], err)
}
switch out := output.(type) {
case bool:
if out {
@ -176,7 +190,7 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision
/*the expression matched, create the associated decision*/
subdecisions, err := Profile.GenerateDecisionFromProfile(Alert)
if err != nil {
return nil, matched, errors.Wrapf(err, "while generating decision from profile %s", Profile.Cfg.Name)
return nil, matched, fmt.Errorf("while generating decision from profile %s: %w", Profile.Cfg.Name, err)
}
decisions = append(decisions, subdecisions...)
@ -189,9 +203,7 @@ func (Profile *Runtime) EvaluateProfile(Alert *models.Alert) ([]*models.Decision
default:
return nil, matched, fmt.Errorf("unexpected type %t (%v) while running '%s'", output, output, Profile.Cfg.Filters[eIdx])
}
}
return decisions, matched, nil

View file

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/ptr"
)
@ -36,25 +37,30 @@ func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
// wip
func fireHandler(req *http.Request) *http.Response {
var err error
apiKey := req.Header.Get("x-api-key")
if apiKey != validApiKey {
log.Warningf("invalid api key: %s", apiKey)
return &http.Response{
StatusCode: http.StatusForbidden,
Body: nil,
Header: make(http.Header),
}
}
//unmarshal data
if fireResponses == nil {
page1, err := os.ReadFile("tests/fire-page1.json")
if err != nil {
panic("can't read file")
}
page2, err := os.ReadFile("tests/fire-page2.json")
if err != nil {
panic("can't read file")
}
fireResponses = []string{string(page1), string(page2)}
}
//let's assume we have two valid pages.
@ -70,6 +76,7 @@ func fireHandler(req *http.Request) *http.Response {
//how to react if you give a page number that is too big ?
if page > len(fireResponses) {
log.Warningf(" page too big %d vs %d", page, len(fireResponses))
emptyResponse := `{
"_links": {
"first": {
@ -82,8 +89,10 @@ func fireHandler(req *http.Request) *http.Response {
"items": []
}
`
return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(emptyResponse))}
}
reader := io.NopCloser(strings.NewReader(fireResponses[page-1]))
//we should care about limit too
return &http.Response{
@ -106,6 +115,7 @@ func smokeHandler(req *http.Request) *http.Response {
}
requestedIP := strings.Split(req.URL.Path, "/")[3]
response, ok := smokeResponses[requestedIP]
if !ok {
return &http.Response{
@ -135,6 +145,7 @@ func rateLimitedHandler(req *http.Request) *http.Response {
Header: make(http.Header),
}
}
return &http.Response{
StatusCode: http.StatusTooManyRequests,
Body: nil,
@ -151,7 +162,9 @@ func searchHandler(req *http.Request) *http.Response {
Header: make(http.Header),
}
}
url, _ := url.Parse(req.URL.String())
ipsParam := url.Query().Get("ips")
if ipsParam == "" {
return &http.Response{
@ -163,6 +176,7 @@ func searchHandler(req *http.Request) *http.Response {
totalIps := 0
notFound := 0
ips := strings.Split(ipsParam, ",")
for _, ip := range ips {
_, ok := smokeResponses[ip]
@ -172,12 +186,15 @@ func searchHandler(req *http.Request) *http.Response {
notFound++
}
}
response := fmt.Sprintf(`{"total": %d, "not_found": %d, "items": [`, totalIps, notFound)
for _, ip := range ips {
response += smokeResponses[ip]
}
response += "]}"
reader := io.NopCloser(strings.NewReader(response))
return &http.Response{
StatusCode: http.StatusOK,
Body: reader,
@ -190,7 +207,7 @@ func TestBadFireAuth(t *testing.T) {
Transport: RoundTripFunc(fireHandler),
}))
_, err := ctiClient.Fire(FireParams{})
assert.EqualError(t, err, ErrUnauthorized.Error())
require.EqualError(t, err, ErrUnauthorized.Error())
}
func TestFireOk(t *testing.T) {
@ -198,19 +215,19 @@ func TestFireOk(t *testing.T) {
Transport: RoundTripFunc(fireHandler),
}))
data, err := cticlient.Fire(FireParams{})
assert.Equal(t, err, nil)
assert.Equal(t, len(data.Items), 3)
assert.Equal(t, data.Items[0].Ip, "1.2.3.4")
require.NoError(t, err)
assert.Len(t, data.Items, 3)
assert.Equal(t, "1.2.3.4", data.Items[0].Ip)
//page 1 is the default
data, err = cticlient.Fire(FireParams{Page: ptr.Of(1)})
assert.Equal(t, err, nil)
assert.Equal(t, len(data.Items), 3)
assert.Equal(t, data.Items[0].Ip, "1.2.3.4")
require.NoError(t, err)
assert.Len(t, data.Items, 3)
assert.Equal(t, "1.2.3.4", data.Items[0].Ip)
//page 2
data, err = cticlient.Fire(FireParams{Page: ptr.Of(2)})
assert.Equal(t, err, nil)
assert.Equal(t, len(data.Items), 3)
assert.Equal(t, data.Items[0].Ip, "4.2.3.4")
require.NoError(t, err)
assert.Len(t, data.Items, 3)
assert.Equal(t, "4.2.3.4", data.Items[0].Ip)
}
func TestFirePaginator(t *testing.T) {
@ -219,17 +236,16 @@ func TestFirePaginator(t *testing.T) {
}))
paginator := NewFirePaginator(cticlient, FireParams{})
items, err := paginator.Next()
assert.Equal(t, err, nil)
assert.Equal(t, len(items), 3)
assert.Equal(t, items[0].Ip, "1.2.3.4")
require.NoError(t, err)
assert.Len(t, items, 3)
assert.Equal(t, "1.2.3.4", items[0].Ip)
items, err = paginator.Next()
assert.Equal(t, err, nil)
assert.Equal(t, len(items), 3)
assert.Equal(t, items[0].Ip, "4.2.3.4")
require.NoError(t, err)
assert.Len(t, items, 3)
assert.Equal(t, "4.2.3.4", items[0].Ip)
items, err = paginator.Next()
assert.Equal(t, err, nil)
assert.Equal(t, len(items), 0)
require.NoError(t, err)
assert.Empty(t, items)
}
func TestBadSmokeAuth(t *testing.T) {
@ -237,13 +253,14 @@ func TestBadSmokeAuth(t *testing.T) {
Transport: RoundTripFunc(smokeHandler),
}))
_, err := ctiClient.GetIPInfo("1.1.1.1")
assert.EqualError(t, err, ErrUnauthorized.Error())
require.EqualError(t, err, ErrUnauthorized.Error())
}
func TestSmokeInfoValidIP(t *testing.T) {
ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{
Transport: RoundTripFunc(smokeHandler),
}))
resp, err := ctiClient.GetIPInfo("1.1.1.1")
if err != nil {
t.Fatalf("failed to get ip info: %s", err)
@ -257,6 +274,7 @@ func TestSmokeUnknownIP(t *testing.T) {
ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{
Transport: RoundTripFunc(smokeHandler),
}))
resp, err := ctiClient.GetIPInfo("42.42.42.42")
if err != nil {
t.Fatalf("failed to get ip info: %s", err)
@ -270,20 +288,22 @@ func TestRateLimit(t *testing.T) {
Transport: RoundTripFunc(rateLimitedHandler),
}))
_, err := ctiClient.GetIPInfo("1.1.1.1")
assert.EqualError(t, err, ErrLimit.Error())
require.EqualError(t, err, ErrLimit.Error())
}
func TestSearchIPs(t *testing.T) {
ctiClient := NewCrowdsecCTIClient(WithAPIKey(validApiKey), WithHTTPClient(&http.Client{
Transport: RoundTripFunc(searchHandler),
}))
resp, err := ctiClient.SearchIPs([]string{"1.1.1.1", "42.42.42.42"})
if err != nil {
t.Fatalf("failed to search ips: %s", err)
}
assert.Equal(t, 1, resp.Total)
assert.Equal(t, 1, resp.NotFound)
assert.Equal(t, 1, len(resp.Items))
assert.Len(t, resp.Items, 1)
assert.Equal(t, "1.1.1.1", resp.Items[0].Ip)
}

View file

@ -88,27 +88,28 @@ func getSampleSmokeItem() SmokeItem {
},
},
}
return emptyItem
}
func TestBasicSmokeItem(t *testing.T) {
item := getSampleSmokeItem()
assert.Equal(t, item.GetAttackDetails(), []string{"ssh:bruteforce"})
assert.Equal(t, item.GetBehaviors(), []string{"ssh:bruteforce"})
assert.Equal(t, item.GetMaliciousnessScore(), float32(0.1))
assert.Equal(t, item.IsPartOfCommunityBlocklist(), false)
assert.Equal(t, item.GetBackgroundNoiseScore(), int(3))
assert.Equal(t, item.GetFalsePositives(), []string{})
assert.Equal(t, item.IsFalsePositive(), false)
assert.Equal(t, []string{"ssh:bruteforce"}, item.GetAttackDetails())
assert.Equal(t, []string{"ssh:bruteforce"}, item.GetBehaviors())
assert.InDelta(t, 0.1, item.GetMaliciousnessScore(), 0.000001)
assert.False(t, item.IsPartOfCommunityBlocklist())
assert.Equal(t, 3, item.GetBackgroundNoiseScore())
assert.Equal(t, []string{}, item.GetFalsePositives())
assert.False(t, item.IsFalsePositive())
}
func TestEmptySmokeItem(t *testing.T) {
item := SmokeItem{}
assert.Equal(t, item.GetAttackDetails(), []string{})
assert.Equal(t, item.GetBehaviors(), []string{})
assert.Equal(t, item.GetMaliciousnessScore(), float32(0.0))
assert.Equal(t, item.IsPartOfCommunityBlocklist(), false)
assert.Equal(t, item.GetBackgroundNoiseScore(), int(0))
assert.Equal(t, item.GetFalsePositives(), []string{})
assert.Equal(t, item.IsFalsePositive(), false)
assert.Equal(t, []string{}, item.GetAttackDetails())
assert.Equal(t, []string{}, item.GetBehaviors())
assert.InDelta(t, 0.0, item.GetMaliciousnessScore(), 0)
assert.False(t, item.IsPartOfCommunityBlocklist())
assert.Equal(t, 0, item.GetBackgroundNoiseScore())
assert.Equal(t, []string{}, item.GetFalsePositives())
assert.False(t, item.IsFalsePositive())
}

View file

@ -56,6 +56,7 @@ func downloadFile(url string, destPath string) error {
// if the remote has no modification date, but local file has been modified > a week ago, update.
func needsUpdate(destPath string, url string, logger *logrus.Logger) bool {
fileInfo, err := os.Stat(destPath)
switch {
case os.IsNotExist(err):
return true
@ -89,6 +90,7 @@ func needsUpdate(destPath string, url string, logger *logrus.Logger) bool {
if localIsOld {
logger.Infof("no last modified date for %s, but local file is older than %s", url, shelfLife)
}
return localIsOld
}
@ -129,6 +131,7 @@ func downloadDataSet(dataFolder string, force bool, reader io.Reader, logger *lo
if force || needsUpdate(destPath, dataS.SourceURL, logger) {
logger.Debugf("downloading %s in %s", dataS.SourceURL, destPath)
if err := downloadFile(dataS.SourceURL, destPath); err != nil {
return fmt.Errorf("while getting data: %w", err)
}

View file

@ -7,10 +7,10 @@ import (
"io"
"os"
"path"
"slices"
"strings"
"github.com/sirupsen/logrus"
"slices"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
)
@ -97,6 +97,10 @@ func (h *Hub) parseIndex() error {
item.FileName = path.Base(item.RemotePath)
item.logMissingSubItems()
if item.latestHash() == "" {
h.logger.Errorf("invalid hub item %s: latest version missing from index", item.FQName())
}
}
}

View file

@ -12,7 +12,6 @@ import (
func TestInitHubUpdate(t *testing.T) {
hub := envSetup(t)
remote := &RemoteHubCfg{
URLTemplate: mockURLTemplate,
Branch: "master",

View file

@ -4,10 +4,10 @@ import (
"encoding/json"
"fmt"
"path/filepath"
"slices"
"github.com/Masterminds/semver/v3"
"github.com/enescakir/emoji"
"slices"
)
const (
@ -440,3 +440,15 @@ func (i *Item) addTaint(sub *Item) {
ancestor.addTaint(sub)
}
}
// latestHash() returns the hash of the latest version of the item.
// if it's missing, the index file has been manually modified or got corrupted.
func (i *Item) latestHash() string {
for k, v := range i.Versions {
if k == i.Version {
return v.Digest
}
}
return ""
}

View file

@ -50,7 +50,7 @@ func (i *Item) Install(force bool, downloadOnly bool) error {
filePath, err := i.downloadLatest(force, true)
if err != nil {
return fmt.Errorf("while downloading %s: %w", i.Name, err)
return err
}
if downloadOnly {

View file

@ -3,7 +3,6 @@ package cwhub
import (
"fmt"
"os"
"slices"
)

View file

@ -6,6 +6,7 @@ import (
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
@ -82,14 +83,14 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (string, error) {
i.hub.logger.Tracef("collection, recurse")
if _, err := sub.downloadLatest(overwrite, updateOnly); err != nil {
return "", fmt.Errorf("while downloading %s: %w", sub.Name, err)
return "", err
}
}
downloaded := sub.State.Downloaded
if _, err := sub.download(overwrite); err != nil {
return "", fmt.Errorf("while downloading %s: %w", sub.Name, err)
return "", err
}
// We need to enable an item when it has been added to a collection since latest release of the collection.
@ -108,7 +109,7 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (string, error) {
ret, err := i.download(overwrite)
if err != nil {
return "", fmt.Errorf("failed to download item: %w", err)
return "", err
}
return ret, nil
@ -116,6 +117,10 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (string, error) {
// FetchLatest downloads the latest item from the hub, verifies the hash and returns the content and the used url.
func (i *Item) FetchLatest() ([]byte, string, error) {
if i.latestHash() == "" {
return nil, "", errors.New("latest hash missing from index")
}
url, err := i.hub.remote.urlTo(i.RemotePath)
if err != nil {
return nil, "", fmt.Errorf("failed to build request: %w", err)
@ -146,7 +151,7 @@ func (i *Item) FetchLatest() ([]byte, string, error) {
i.hub.logger.Errorf("Downloaded version doesn't match index, please 'hub update'")
i.hub.logger.Debugf("got %s, expected %s", meow, i.Versions[i.Version].Digest)
return nil, "", fmt.Errorf("invalid download hash for %s", i.Name)
return nil, "", fmt.Errorf("invalid download hash")
}
return body, url, nil
@ -180,7 +185,12 @@ func (i *Item) download(overwrite bool) (string, error) {
body, url, err := i.FetchLatest()
if err != nil {
return "", fmt.Errorf("while downloading %s: %w", url, err)
what := i.Name
if url != "" {
what += " from " + url
}
return "", fmt.Errorf("while downloading %s: %w", what, err)
}
// all good, install

View file

@ -12,7 +12,6 @@ import (
// We expect the new scenario to be installed.
func TestUpgradeItemNewScenarioInCollection(t *testing.T) {
hub := envSetup(t)
item := hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection")
// fresh install of collection
@ -65,7 +64,6 @@ func TestUpgradeItemNewScenarioInCollection(t *testing.T) {
// Upgrade should install should not enable/download the disabled scenario.
func TestUpgradeItemInDisabledScenarioShouldNotBeInstalled(t *testing.T) {
hub := envSetup(t)
item := hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection")
// fresh install of collection
@ -127,7 +125,6 @@ func getHubOrFail(t *testing.T, local *csconfig.LocalHubCfg, remote *RemoteHubCf
// Upgrade should install and enable the newly added scenario.
func TestUpgradeItemNewScenarioIsInstalledWhenReferencedScenarioIsDisabled(t *testing.T) {
hub := envSetup(t)
item := hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection")
// fresh install of collection

View file

@ -7,13 +7,13 @@ import (
"io"
"os"
"path/filepath"
"slices"
"sort"
"strings"
"github.com/Masterminds/semver/v3"
"github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
"slices"
)
func isYAMLFileName(path string) bool {

32
pkg/dumps/bucket_dump.go Normal file
View file

@ -0,0 +1,32 @@
package dumps
import (
"io"
"os"
"github.com/crowdsecurity/crowdsec/pkg/types"
"gopkg.in/yaml.v2"
)
type BucketPourInfo map[string][]types.Event
func LoadBucketPourDump(filepath string) (*BucketPourInfo, error) {
dumpData, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer dumpData.Close()
results, err := io.ReadAll(dumpData)
if err != nil {
return nil, err
}
var bucketDump BucketPourInfo
if err := yaml.Unmarshal(results, &bucketDump); err != nil {
return nil, err
}
return &bucketDump, nil
}

319
pkg/dumps/parser_dump.go Normal file
View file

@ -0,0 +1,319 @@
package dumps
import (
"fmt"
"io"
"os"
"sort"
"strings"
"time"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/crowdsecurity/go-cs-lib/maptools"
"github.com/enescakir/emoji"
"github.com/fatih/color"
diff "github.com/r3labs/diff/v2"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
type ParserResult struct {
Idx int
Evt types.Event
Success bool
}
type ParserResults map[string]map[string][]ParserResult
type DumpOpts struct {
Details bool
SkipOk bool
ShowNotOkParsers bool
}
func LoadParserDump(filepath string) (*ParserResults, error) {
dumpData, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer dumpData.Close()
results, err := io.ReadAll(dumpData)
if err != nil {
return nil, err
}
pdump := ParserResults{}
if err := yaml.Unmarshal(results, &pdump); err != nil {
return nil, err
}
/* we know that some variables should always be set,
let's check if they're present in last parser output of last stage */
stages := maptools.SortedKeys(pdump)
var lastStage string
//Loop over stages to find last successful one with at least one parser
for i := len(stages) - 2; i >= 0; i-- {
if len(pdump[stages[i]]) != 0 {
lastStage = stages[i]
break
}
}
parsers := make([]string, 0, len(pdump[lastStage]))
for k := range pdump[lastStage] {
parsers = append(parsers, k)
}
sort.Strings(parsers)
if len(parsers) == 0 {
return nil, fmt.Errorf("no parser found. Please install the appropriate parser and retry")
}
lastParser := parsers[len(parsers)-1]
for idx, result := range pdump[lastStage][lastParser] {
if result.Evt.StrTime == "" {
log.Warningf("Line %d/%d is missing evt.StrTime. It is most likely a mistake as it will prevent your logs to be processed in time-machine/forensic mode.", idx, len(pdump[lastStage][lastParser]))
} else {
log.Debugf("Line %d/%d has evt.StrTime set to '%s'", idx, len(pdump[lastStage][lastParser]), result.Evt.StrTime)
}
}
return &pdump, nil
}
func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpOpts) {
//note : we can use line -> time as the unique identifier (of acquisition)
state := make(map[time.Time]map[string]map[string]ParserResult)
assoc := make(map[time.Time]string, 0)
parser_order := make(map[string][]string)
for stage, parsers := range parserResults {
//let's process parsers in the order according to idx
parser_order[stage] = make([]string, len(parsers))
for pname, parser := range parsers {
if len(parser) > 0 {
parser_order[stage][parser[0].Idx-1] = pname
}
}
for _, parser := range parser_order[stage] {
results := parsers[parser]
for _, parserRes := range results {
evt := parserRes.Evt
if _, ok := state[evt.Line.Time]; !ok {
state[evt.Line.Time] = make(map[string]map[string]ParserResult)
assoc[evt.Line.Time] = evt.Line.Raw
}
if _, ok := state[evt.Line.Time][stage]; !ok {
state[evt.Line.Time][stage] = make(map[string]ParserResult)
}
state[evt.Line.Time][stage][parser] = ParserResult{Evt: evt, Success: parserRes.Success}
}
}
}
for bname, evtlist := range bucketPour {
for _, evt := range evtlist {
if evt.Line.Raw == "" {
continue
}
//it might be bucket overflow being reprocessed, skip this
if _, ok := state[evt.Line.Time]; !ok {
state[evt.Line.Time] = make(map[string]map[string]ParserResult)
assoc[evt.Line.Time] = evt.Line.Raw
}
//there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase
//we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered
if _, ok := state[evt.Line.Time]["buckets"]; !ok {
state[evt.Line.Time]["buckets"] = make(map[string]ParserResult)
}
state[evt.Line.Time]["buckets"][bname] = ParserResult{Success: true}
}
}
yellow := color.New(color.FgYellow).SprintFunc()
red := color.New(color.FgRed).SprintFunc()
green := color.New(color.FgGreen).SprintFunc()
whitelistReason := ""
//get each line
for tstamp, rawstr := range assoc {
if opts.SkipOk {
if _, ok := state[tstamp]["buckets"]["OK"]; ok {
continue
}
}
fmt.Printf("line: %s\n", rawstr)
skeys := make([]string, 0, len(state[tstamp]))
for k := range state[tstamp] {
//there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase
//we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered
if k == "buckets" {
continue
}
skeys = append(skeys, k)
}
sort.Strings(skeys)
// iterate stage
var prevItem types.Event
for _, stage := range skeys {
parsers := state[tstamp][stage]
sep := "├"
presep := "|"
fmt.Printf("\t%s %s\n", sep, stage)
for idx, parser := range parser_order[stage] {
res := parsers[parser].Success
sep := "├"
if idx == len(parser_order[stage])-1 {
sep = "└"
}
created := 0
updated := 0
deleted := 0
whitelisted := false
changeStr := ""
detailsDisplay := ""
if res {
changelog, _ := diff.Diff(prevItem, parsers[parser].Evt)
for _, change := range changelog {
switch change.Type {
case "create":
created++
detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), green(change.To))
case "update":
detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s -> %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), change.From, yellow(change.To))
if change.Path[0] == "Whitelisted" && change.To == true {
whitelisted = true
if whitelistReason == "" {
whitelistReason = parsers[parser].Evt.WhitelistReason
}
}
updated++
case "delete":
deleted++
detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s\n", presep, sep, change.Type, red(strings.Join(change.Path, ".")))
}
}
prevItem = parsers[parser].Evt
}
if created > 0 {
changeStr += green(fmt.Sprintf("+%d", created))
}
if updated > 0 {
if len(changeStr) > 0 {
changeStr += " "
}
changeStr += yellow(fmt.Sprintf("~%d", updated))
}
if deleted > 0 {
if len(changeStr) > 0 {
changeStr += " "
}
changeStr += red(fmt.Sprintf("-%d", deleted))
}
if whitelisted {
if len(changeStr) > 0 {
changeStr += " "
}
changeStr += red("[whitelisted]")
}
if changeStr == "" {
changeStr = yellow("unchanged")
}
if res {
fmt.Printf("\t%s\t%s %s %s (%s)\n", presep, sep, emoji.GreenCircle, parser, changeStr)
if opts.Details {
fmt.Print(detailsDisplay)
}
} else if opts.ShowNotOkParsers {
fmt.Printf("\t%s\t%s %s %s\n", presep, sep, emoji.RedCircle, parser)
}
}
}
sep := "└"
if len(state[tstamp]["buckets"]) > 0 {
sep = "├"
}
//did the event enter the bucket pour phase ?
if _, ok := state[tstamp]["buckets"]["OK"]; ok {
fmt.Printf("\t%s-------- parser success %s\n", sep, emoji.GreenCircle)
} else if whitelistReason != "" {
fmt.Printf("\t%s-------- parser success, ignored by whitelist (%s) %s\n", sep, whitelistReason, emoji.GreenCircle)
} else {
fmt.Printf("\t%s-------- parser failure %s\n", sep, emoji.RedCircle)
}
//now print bucket info
if len(state[tstamp]["buckets"]) > 0 {
fmt.Printf("\t├ Scenarios\n")
}
bnames := make([]string, 0, len(state[tstamp]["buckets"]))
for k := range state[tstamp]["buckets"] {
//there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase
//we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered
if k == "OK" {
continue
}
bnames = append(bnames, k)
}
sort.Strings(bnames)
for idx, bname := range bnames {
sep := "├"
if idx == len(bnames)-1 {
sep = "└"
}
fmt.Printf("\t\t%s %s %s\n", sep, emoji.GreenCircle, bname)
}
fmt.Println()
}
}

View file

@ -11,6 +11,7 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/crowdsecurity/go-cs-lib/ptr"
@ -78,6 +79,7 @@ func smokeHandler(req *http.Request) *http.Response {
}
requestedIP := strings.Split(req.URL.Path, "/")[3]
sample, ok := sampledata[requestedIP]
if !ok {
return &http.Response{
@ -109,9 +111,11 @@ func smokeHandler(req *http.Request) *http.Response {
func TestNillClient(t *testing.T) {
defer ShutdownCrowdsecCTI()
if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); !errors.Is(err, cticlient.ErrDisabled) {
t.Fatalf("failed to init CTI : %s", err)
}
item, err := CrowdsecCTI("1.2.3.4")
assert.Equal(t, err, cticlient.ErrDisabled)
assert.Equal(t, item, &cticlient.SmokeItem{})
@ -119,6 +123,7 @@ func TestNillClient(t *testing.T) {
func TestInvalidAuth(t *testing.T) {
defer ShutdownCrowdsecCTI()
if err := InitCrowdsecCTI(ptr.Of("asdasd"), nil, nil, nil); err != nil {
t.Fatalf("failed to init CTI : %s", err)
}
@ -129,7 +134,7 @@ func TestInvalidAuth(t *testing.T) {
item, err := CrowdsecCTI("1.2.3.4")
assert.Equal(t, item, &cticlient.SmokeItem{})
assert.Equal(t, CTIApiEnabled, false)
assert.False(t, CTIApiEnabled)
assert.Equal(t, err, cticlient.ErrUnauthorized)
//CTI is now disabled, all requests should return empty
@ -139,14 +144,15 @@ func TestInvalidAuth(t *testing.T) {
item, err = CrowdsecCTI("1.2.3.4")
assert.Equal(t, item, &cticlient.SmokeItem{})
assert.Equal(t, CTIApiEnabled, false)
assert.False(t, CTIApiEnabled)
assert.Equal(t, err, cticlient.ErrDisabled)
}
func TestNoKey(t *testing.T) {
defer ShutdownCrowdsecCTI()
err := InitCrowdsecCTI(nil, nil, nil, nil)
assert.ErrorIs(t, err, cticlient.ErrDisabled)
require.ErrorIs(t, err, cticlient.ErrDisabled)
//Replace the client created by InitCrowdsecCTI with one that uses a custom transport
ctiClient = cticlient.NewCrowdsecCTIClient(cticlient.WithAPIKey("asdasd"), cticlient.WithHTTPClient(&http.Client{
Transport: RoundTripFunc(smokeHandler),
@ -154,12 +160,13 @@ func TestNoKey(t *testing.T) {
item, err := CrowdsecCTI("1.2.3.4")
assert.Equal(t, item, &cticlient.SmokeItem{})
assert.Equal(t, CTIApiEnabled, false)
assert.False(t, CTIApiEnabled)
assert.Equal(t, err, cticlient.ErrDisabled)
}
func TestCache(t *testing.T) {
defer ShutdownCrowdsecCTI()
cacheDuration := 1 * time.Second
if err := InitCrowdsecCTI(ptr.Of(validApiKey), &cacheDuration, nil, nil); err != nil {
t.Fatalf("failed to init CTI : %s", err)
@ -172,28 +179,27 @@ func TestCache(t *testing.T) {
item, err := CrowdsecCTI("1.2.3.4")
ctiResp := item.(*cticlient.SmokeItem)
assert.Equal(t, "1.2.3.4", ctiResp.Ip)
assert.Equal(t, CTIApiEnabled, true)
assert.Equal(t, CTICache.Len(true), 1)
assert.Equal(t, err, nil)
assert.True(t, CTIApiEnabled)
assert.Equal(t, 1, CTICache.Len(true))
require.NoError(t, err)
item, err = CrowdsecCTI("1.2.3.4")
ctiResp = item.(*cticlient.SmokeItem)
assert.Equal(t, "1.2.3.4", ctiResp.Ip)
assert.Equal(t, CTIApiEnabled, true)
assert.Equal(t, CTICache.Len(true), 1)
assert.Equal(t, err, nil)
assert.True(t, CTIApiEnabled)
assert.Equal(t, 1, CTICache.Len(true))
require.NoError(t, err)
time.Sleep(2 * time.Second)
assert.Equal(t, CTICache.Len(true), 0)
assert.Equal(t, 0, CTICache.Len(true))
item, err = CrowdsecCTI("1.2.3.4")
ctiResp = item.(*cticlient.SmokeItem)
assert.Equal(t, "1.2.3.4", ctiResp.Ip)
assert.Equal(t, CTIApiEnabled, true)
assert.Equal(t, CTICache.Len(true), 1)
assert.Equal(t, err, nil)
assert.True(t, CTIApiEnabled)
assert.Equal(t, 1, CTICache.Len(true))
require.NoError(t, err)
}

View file

@ -28,17 +28,18 @@ var (
func getDBClient(t *testing.T) *database.Client {
t.Helper()
dbPath, err := os.CreateTemp("", "*sqlite")
require.NoError(t, err)
testDbClient, err := database.NewClient(&csconfig.DatabaseCfg{
testDBClient, err := database.NewClient(&csconfig.DatabaseCfg{
Type: "sqlite",
DbName: "crowdsec",
DbPath: dbPath.Name(),
})
require.NoError(t, err)
return testDbClient
return testDBClient
}
func TestVisitor(t *testing.T) {
@ -109,17 +110,18 @@ func TestVisitor(t *testing.T) {
if err != nil && test.err == nil {
log.Fatalf("run : %s", err)
}
if isOk := assert.Equal(t, test.result, result); !isOk {
t.Fatalf("test '%s' : NOK", test.filter)
}
}
}
}
func TestMatch(t *testing.T) {
err := Init(nil)
require.NoError(t, err)
tests := []struct {
glob string
val string
@ -149,12 +151,15 @@ func TestMatch(t *testing.T) {
"pattern": test.glob,
"name": test.val,
}
vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
if err != nil {
t.Fatalf("pattern:%s val:%s NOK %s", test.glob, test.val, err)
}
ret, err := expr.Run(vm, env)
assert.NoError(t, err)
require.NoError(t, err)
if isOk := assert.Equal(t, test.ret, ret); !isOk {
t.Fatalf("pattern:%s val:%s NOK %t != %t", test.glob, test.val, ret, test.ret)
}
@ -194,10 +199,10 @@ func TestDistanceHelper(t *testing.T) {
}
ret, err := expr.Run(vm, env)
if test.valid {
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, test.dist, ret)
} else {
assert.NotNil(t, err)
require.Error(t, err)
}
})
}
@ -283,10 +288,12 @@ func TestRegexpInFile(t *testing.T) {
if err != nil {
log.Fatal(err)
}
result, err := expr.Run(compiledFilter, map[string]interface{}{})
if err != nil {
log.Fatal(err)
}
if isOk := assert.Equal(t, test.result, result); !isOk {
t.Fatalf("test '%s' : NOK", test.name)
}
@ -335,28 +342,34 @@ func TestFileInit(t *testing.T) {
if err != nil {
log.Fatal(err)
}
if test.types == "string" {
switch test.types {
case "string":
if _, ok := dataFile[test.filename]; !ok {
t.Fatalf("test '%s' : NOK", test.name)
}
if isOk := assert.Equal(t, test.result, len(dataFile[test.filename])); !isOk {
if isOk := assert.Len(t, dataFile[test.filename], test.result); !isOk {
t.Fatalf("test '%s' : NOK", test.name)
}
} else if test.types == "regex" {
case "regex":
if _, ok := dataFileRegex[test.filename]; !ok {
t.Fatalf("test '%s' : NOK", test.name)
}
if isOk := assert.Equal(t, test.result, len(dataFileRegex[test.filename])); !isOk {
if isOk := assert.Len(t, dataFileRegex[test.filename], test.result); !isOk {
t.Fatalf("test '%s' : NOK", test.name)
}
} else {
default:
if _, ok := dataFileRegex[test.filename]; ok {
t.Fatalf("test '%s' : NOK", test.name)
}
if _, ok := dataFile[test.filename]; ok {
t.Fatalf("test '%s' : NOK", test.name)
}
}
log.Printf("test '%s' : OK", test.name)
}
}
@ -408,21 +421,23 @@ func TestFile(t *testing.T) {
if err != nil {
log.Fatal(err)
}
result, err := expr.Run(compiledFilter, map[string]interface{}{})
if err != nil {
log.Fatal(err)
}
if isOk := assert.Equal(t, test.result, result); !isOk {
t.Fatalf("test '%s' : NOK", test.name)
}
log.Printf("test '%s' : OK", test.name)
log.Printf("test '%s' : OK", test.name)
}
}
func TestIpInRange(t *testing.T) {
err := Init(nil)
assert.NoError(t, err)
require.NoError(t, err)
tests := []struct {
name string
env map[string]interface{}
@ -470,12 +485,11 @@ func TestIpInRange(t *testing.T) {
require.Equal(t, test.result, output)
log.Printf("test '%s' : OK", test.name)
}
}
func TestIpToRange(t *testing.T) {
err := Init(nil)
assert.NoError(t, err)
require.NoError(t, err)
tests := []struct {
name string
env map[string]interface{}
@ -543,13 +557,11 @@ func TestIpToRange(t *testing.T) {
require.Equal(t, test.result, output)
log.Printf("test '%s' : OK", test.name)
}
}
func TestAtof(t *testing.T) {
err := Init(nil)
assert.NoError(t, err)
require.NoError(t, err)
tests := []struct {
name string
@ -593,13 +605,14 @@ func TestUpper(t *testing.T) {
}
err := Init(nil)
assert.NoError(t, err)
require.NoError(t, err)
vm, err := expr.Compile("Upper(testStr)", GetExprOptions(env)...)
assert.NoError(t, err)
require.NoError(t, err)
out, err := expr.Run(vm, env)
assert.NoError(t, err)
require.NoError(t, err)
v, ok := out.(string)
if !ok {
t.Fatalf("Upper() should return a string")
@ -612,6 +625,7 @@ func TestUpper(t *testing.T) {
func TestTimeNow(t *testing.T) {
now, _ := TimeNow()
ti, err := time.Parse(time.RFC3339, now.(string))
if err != nil {
t.Fatalf("Error parsing the return value of TimeNow: %s", err)
@ -620,6 +634,7 @@ func TestTimeNow(t *testing.T) {
if -1*time.Until(ti) > time.Second {
t.Fatalf("TimeNow func should return time.Now().UTC()")
}
log.Printf("test 'TimeNow()' : OK")
}
@ -894,15 +909,14 @@ func TestLower(t *testing.T) {
}
func TestGetDecisionsCount(t *testing.T) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
existingIP := "1.2.3.4"
unknownIP := "1.2.3.5"
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP)
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP)
if err != nil {
t.Errorf("unable to convert '%s' to int: %s", existingIP, err)
}
// Add sample data to DB
dbClient = getDBClient(t)
@ -921,11 +935,11 @@ func TestGetDecisionsCount(t *testing.T) {
SaveX(context.Background())
if decision == nil {
assert.Error(t, errors.Errorf("Failed to create sample decision"))
require.Error(t, errors.Errorf("Failed to create sample decision"))
}
err = Init(dbClient)
assert.NoError(t, err)
require.NoError(t, err)
tests := []struct {
name string
@ -982,12 +996,10 @@ func TestGetDecisionsCount(t *testing.T) {
}
}
func TestGetDecisionsSinceCount(t *testing.T) {
var err error
var start_ip, start_sfx, end_ip, end_sfx int64
var ip_sz int
existingIP := "1.2.3.4"
unknownIP := "1.2.3.5"
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err = types.Addr2Ints(existingIP)
ip_sz, start_ip, start_sfx, end_ip, end_sfx, err := types.Addr2Ints(existingIP)
if err != nil {
t.Errorf("unable to convert '%s' to int: %s", existingIP, err)
}
@ -1008,8 +1020,9 @@ func TestGetDecisionsSinceCount(t *testing.T) {
SetOrigin("CAPI").
SaveX(context.Background())
if decision == nil {
assert.Error(t, errors.Errorf("Failed to create sample decision"))
require.Error(t, errors.Errorf("Failed to create sample decision"))
}
decision2 := dbClient.Ent.Decision.Create().
SetCreatedAt(time.Now().AddDate(0, 0, -1)).
SetUntil(time.Now().AddDate(0, 0, -1)).
@ -1024,12 +1037,13 @@ func TestGetDecisionsSinceCount(t *testing.T) {
SetValue(existingIP).
SetOrigin("CAPI").
SaveX(context.Background())
if decision2 == nil {
assert.Error(t, errors.Errorf("Failed to create sample decision"))
require.Error(t, errors.Errorf("Failed to create sample decision"))
}
err = Init(dbClient)
assert.NoError(t, err)
require.NoError(t, err)
tests := []struct {
name string
@ -1152,6 +1166,7 @@ func TestIsIp(t *testing.T) {
if err := Init(nil); err != nil {
log.Fatal(err)
}
tests := []struct {
name string
expr string
@ -1235,17 +1250,18 @@ func TestIsIp(t *testing.T) {
expectedBuildErr: true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...)
if tc.expectedBuildErr {
assert.Error(t, err)
require.Error(t, err)
return
}
assert.NoError(t, err)
require.NoError(t, err)
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
assert.NoError(t, err)
require.NoError(t, err)
assert.IsType(t, tc.expected, output)
assert.Equal(t, tc.expected, output.(bool))
})
@ -1255,6 +1271,7 @@ func TestIsIp(t *testing.T) {
func TestToString(t *testing.T) {
err := Init(nil)
require.NoError(t, err)
tests := []struct {
name string
value interface{}
@ -1290,9 +1307,9 @@ func TestToString(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...)
assert.NoError(t, err)
require.NoError(t, err)
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
assert.NoError(t, err)
require.NoError(t, err)
require.Equal(t, tc.expected, output)
})
}
@ -1338,16 +1355,16 @@ func TestB64Decode(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
vm, err := expr.Compile(tc.expr, GetExprOptions(map[string]interface{}{"value": tc.value})...)
if tc.expectedBuildErr {
assert.Error(t, err)
require.Error(t, err)
return
}
assert.NoError(t, err)
require.NoError(t, err)
output, err := expr.Run(vm, map[string]interface{}{"value": tc.value})
if tc.expectedRuntimeErr {
assert.Error(t, err)
require.Error(t, err)
return
}
assert.NoError(t, err)
require.NoError(t, err)
require.Equal(t, tc.expected, output)
})
}
@ -1412,9 +1429,9 @@ func TestParseKv(t *testing.T) {
"out": outMap,
}
vm, err := expr.Compile(tc.expr, GetExprOptions(env)...)
assert.NoError(t, err)
require.NoError(t, err)
_, err = expr.Run(vm, env)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, tc.expected, outMap["a"])
})
}

View file

@ -7,6 +7,7 @@ import (
"github.com/antonmedv/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestJsonExtract(t *testing.T) {
@ -56,14 +57,14 @@ func TestJsonExtract(t *testing.T) {
"target": test.targetField,
}
vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err)
require.NoError(t, err)
out, err := expr.Run(vm, env)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, test.expectResult, out)
})
}
}
func TestJsonExtractUnescape(t *testing.T) {
if err := Init(nil); err != nil {
log.Fatal(err)
@ -104,9 +105,9 @@ func TestJsonExtractUnescape(t *testing.T) {
"target": test.targetField,
}
vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err)
require.NoError(t, err)
out, err := expr.Run(vm, env)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, test.expectResult, out)
})
}
@ -167,9 +168,9 @@ func TestJsonExtractSlice(t *testing.T) {
"target": test.targetField,
}
vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err)
require.NoError(t, err)
out, err := expr.Run(vm, env)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, test.expectResult, out)
})
}
@ -223,9 +224,9 @@ func TestJsonExtractObject(t *testing.T) {
"target": test.targetField,
}
vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err)
require.NoError(t, err)
out, err := expr.Run(vm, env)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, test.expectResult, out)
})
}
@ -233,7 +234,8 @@ func TestJsonExtractObject(t *testing.T) {
func TestToJson(t *testing.T) {
err := Init(nil)
assert.NoError(t, err)
require.NoError(t, err)
tests := []struct {
name string
obj interface{}
@ -298,9 +300,9 @@ func TestToJson(t *testing.T) {
"obj": test.obj,
}
vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err)
require.NoError(t, err)
out, err := expr.Run(vm, env)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, test.expectResult, out)
})
}
@ -308,7 +310,8 @@ func TestToJson(t *testing.T) {
func TestUnmarshalJSON(t *testing.T) {
err := Init(nil)
assert.NoError(t, err)
require.NoError(t, err)
tests := []struct {
name string
json string
@ -361,11 +364,10 @@ func TestUnmarshalJSON(t *testing.T) {
"out": outMap,
}
vm, err := expr.Compile(test.expr, GetExprOptions(env)...)
assert.NoError(t, err)
require.NoError(t, err)
_, err = expr.Run(vm, env)
assert.NoError(t, err)
require.NoError(t, err)
assert.Equal(t, test.expectResult, outMap["a"])
})
}
}

View file

@ -9,6 +9,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/appsec/appsec_rule"
"github.com/crowdsecurity/crowdsec/pkg/cwhub"
"github.com/crowdsecurity/go-cs-lib/maptools"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
)
@ -25,7 +26,7 @@ func (h *HubTest) GetAppsecCoverage() ([]Coverage, error) {
}
// populate from hub, iterate in alphabetical order
pkeys := sortedMapKeys(h.HubIndex.GetItemMap(cwhub.APPSEC_RULES))
pkeys := maptools.SortedKeys(h.HubIndex.GetItemMap(cwhub.APPSEC_RULES))
coverage := make([]Coverage, len(pkeys))
for i, name := range pkeys {
@ -84,7 +85,7 @@ func (h *HubTest) GetParsersCoverage() ([]Coverage, error) {
}
// populate from hub, iterate in alphabetical order
pkeys := sortedMapKeys(h.HubIndex.GetItemMap(cwhub.PARSERS))
pkeys := maptools.SortedKeys(h.HubIndex.GetItemMap(cwhub.PARSERS))
coverage := make([]Coverage, len(pkeys))
for i, name := range pkeys {
@ -170,7 +171,7 @@ func (h *HubTest) GetScenariosCoverage() ([]Coverage, error) {
}
// populate from hub, iterate in alphabetical order
pkeys := sortedMapKeys(h.HubIndex.GetItemMap(cwhub.SCENARIOS))
pkeys := maptools.SortedKeys(h.HubIndex.GetItemMap(cwhub.SCENARIOS))
coverage := make([]Coverage, len(pkeys))
for i, name := range pkeys {

View file

@ -3,21 +3,16 @@ package hubtest
import (
"bufio"
"fmt"
"io"
"os"
"sort"
"strings"
"time"
"github.com/antonmedv/expr"
"github.com/enescakir/emoji"
"github.com/fatih/color"
diff "github.com/r3labs/diff/v2"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"github.com/crowdsecurity/crowdsec/pkg/dumps"
"github.com/crowdsecurity/crowdsec/pkg/exprhelpers"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/crowdsecurity/go-cs-lib/maptools"
)
type AssertFail struct {
@ -34,16 +29,9 @@ type ParserAssert struct {
NbAssert int
Fails []AssertFail
Success bool
TestData *ParserResults
TestData *dumps.ParserResults
}
type ParserResult struct {
Evt types.Event
Success bool
}
type ParserResults map[string]map[string][]ParserResult
func NewParserAssert(file string) *ParserAssert {
ParserAssert := &ParserAssert{
File: file,
@ -51,7 +39,7 @@ func NewParserAssert(file string) *ParserAssert {
Success: false,
Fails: make([]AssertFail, 0),
AutoGenAssert: false,
TestData: &ParserResults{},
TestData: &dumps.ParserResults{},
}
return ParserAssert
@ -69,7 +57,7 @@ func (p *ParserAssert) AutoGenFromFile(filename string) (string, error) {
}
func (p *ParserAssert) LoadTest(filename string) error {
parserDump, err := LoadParserDump(filename)
parserDump, err := dumps.LoadParserDump(filename)
if err != nil {
return fmt.Errorf("loading parser dump file: %+v", err)
}
@ -229,13 +217,13 @@ func (p *ParserAssert) AutoGenParserAssert() string {
ret := fmt.Sprintf("len(results) == %d\n", len(*p.TestData))
//sort map keys for consistent order
stages := sortedMapKeys(*p.TestData)
stages := maptools.SortedKeys(*p.TestData)
for _, stage := range stages {
parsers := (*p.TestData)[stage]
//sort map keys for consistent order
pnames := sortedMapKeys(parsers)
pnames := maptools.SortedKeys(parsers)
for _, parser := range pnames {
presults := parsers[parser]
@ -248,7 +236,7 @@ func (p *ParserAssert) AutoGenParserAssert() string {
continue
}
for _, pkey := range sortedMapKeys(result.Evt.Parsed) {
for _, pkey := range maptools.SortedKeys(result.Evt.Parsed) {
pval := result.Evt.Parsed[pkey]
if pval == "" {
continue
@ -257,7 +245,7 @@ func (p *ParserAssert) AutoGenParserAssert() string {
ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Parsed["%s"] == "%s"`+"\n", stage, parser, pidx, pkey, Escape(pval))
}
for _, mkey := range sortedMapKeys(result.Evt.Meta) {
for _, mkey := range maptools.SortedKeys(result.Evt.Meta) {
mval := result.Evt.Meta[mkey]
if mval == "" {
continue
@ -266,7 +254,7 @@ func (p *ParserAssert) AutoGenParserAssert() string {
ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Meta["%s"] == "%s"`+"\n", stage, parser, pidx, mkey, Escape(mval))
}
for _, ekey := range sortedMapKeys(result.Evt.Enriched) {
for _, ekey := range maptools.SortedKeys(result.Evt.Enriched) {
eval := result.Evt.Enriched[ekey]
if eval == "" {
continue
@ -275,7 +263,7 @@ func (p *ParserAssert) AutoGenParserAssert() string {
ret += fmt.Sprintf(`results["%s"]["%s"][%d].Evt.Enriched["%s"] == "%s"`+"\n", stage, parser, pidx, ekey, Escape(eval))
}
for _, ukey := range sortedMapKeys(result.Evt.Unmarshaled) {
for _, ukey := range maptools.SortedKeys(result.Evt.Unmarshaled) {
uval := result.Evt.Unmarshaled[ukey]
if uval == "" {
continue
@ -328,288 +316,3 @@ func (p *ParserAssert) buildUnmarshaledAssert(ekey string, eval interface{}) []s
return ret
}
func LoadParserDump(filepath string) (*ParserResults, error) {
dumpData, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer dumpData.Close()
results, err := io.ReadAll(dumpData)
if err != nil {
return nil, err
}
pdump := ParserResults{}
if err := yaml.Unmarshal(results, &pdump); err != nil {
return nil, err
}
/* we know that some variables should always be set,
let's check if they're present in last parser output of last stage */
stages := sortedMapKeys(pdump)
var lastStage string
//Loop over stages to find last successful one with at least one parser
for i := len(stages) - 2; i >= 0; i-- {
if len(pdump[stages[i]]) != 0 {
lastStage = stages[i]
break
}
}
parsers := make([]string, 0, len(pdump[lastStage]))
for k := range pdump[lastStage] {
parsers = append(parsers, k)
}
sort.Strings(parsers)
if len(parsers) == 0 {
return nil, fmt.Errorf("no parser found. Please install the appropriate parser and retry")
}
lastParser := parsers[len(parsers)-1]
for idx, result := range pdump[lastStage][lastParser] {
if result.Evt.StrTime == "" {
log.Warningf("Line %d/%d is missing evt.StrTime. It is most likely a mistake as it will prevent your logs to be processed in time-machine/forensic mode.", idx, len(pdump[lastStage][lastParser]))
} else {
log.Debugf("Line %d/%d has evt.StrTime set to '%s'", idx, len(pdump[lastStage][lastParser]), result.Evt.StrTime)
}
}
return &pdump, nil
}
type DumpOpts struct {
Details bool
SkipOk bool
ShowNotOkParsers bool
}
func DumpTree(parserResults ParserResults, bucketPour BucketPourInfo, opts DumpOpts) {
//note : we can use line -> time as the unique identifier (of acquisition)
state := make(map[time.Time]map[string]map[string]ParserResult)
assoc := make(map[time.Time]string, 0)
for stage, parsers := range parserResults {
for parser, results := range parsers {
for _, parserRes := range results {
evt := parserRes.Evt
if _, ok := state[evt.Line.Time]; !ok {
state[evt.Line.Time] = make(map[string]map[string]ParserResult)
assoc[evt.Line.Time] = evt.Line.Raw
}
if _, ok := state[evt.Line.Time][stage]; !ok {
state[evt.Line.Time][stage] = make(map[string]ParserResult)
}
state[evt.Line.Time][stage][parser] = ParserResult{Evt: evt, Success: parserRes.Success}
}
}
}
for bname, evtlist := range bucketPour {
for _, evt := range evtlist {
if evt.Line.Raw == "" {
continue
}
//it might be bucket overflow being reprocessed, skip this
if _, ok := state[evt.Line.Time]; !ok {
state[evt.Line.Time] = make(map[string]map[string]ParserResult)
assoc[evt.Line.Time] = evt.Line.Raw
}
//there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase
//we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered
if _, ok := state[evt.Line.Time]["buckets"]; !ok {
state[evt.Line.Time]["buckets"] = make(map[string]ParserResult)
}
state[evt.Line.Time]["buckets"][bname] = ParserResult{Success: true}
}
}
yellow := color.New(color.FgYellow).SprintFunc()
red := color.New(color.FgRed).SprintFunc()
green := color.New(color.FgGreen).SprintFunc()
whitelistReason := ""
//get each line
for tstamp, rawstr := range assoc {
if opts.SkipOk {
if _, ok := state[tstamp]["buckets"]["OK"]; ok {
continue
}
}
fmt.Printf("line: %s\n", rawstr)
skeys := make([]string, 0, len(state[tstamp]))
for k := range state[tstamp] {
//there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase
//we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered
if k == "buckets" {
continue
}
skeys = append(skeys, k)
}
sort.Strings(skeys)
// iterate stage
var prevItem types.Event
for _, stage := range skeys {
parsers := state[tstamp][stage]
sep := "├"
presep := "|"
fmt.Printf("\t%s %s\n", sep, stage)
pkeys := sortedMapKeys(parsers)
for idx, parser := range pkeys {
res := parsers[parser].Success
sep := "├"
if idx == len(pkeys)-1 {
sep = "└"
}
created := 0
updated := 0
deleted := 0
whitelisted := false
changeStr := ""
detailsDisplay := ""
if res {
changelog, _ := diff.Diff(prevItem, parsers[parser].Evt)
for _, change := range changelog {
switch change.Type {
case "create":
created++
detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), green(change.To))
case "update":
detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s : %s -> %s\n", presep, sep, change.Type, strings.Join(change.Path, "."), change.From, yellow(change.To))
if change.Path[0] == "Whitelisted" && change.To == true {
whitelisted = true
if whitelistReason == "" {
whitelistReason = parsers[parser].Evt.WhitelistReason
}
}
updated++
case "delete":
deleted++
detailsDisplay += fmt.Sprintf("\t%s\t\t%s %s evt.%s\n", presep, sep, change.Type, red(strings.Join(change.Path, ".")))
}
}
prevItem = parsers[parser].Evt
}
if created > 0 {
changeStr += green(fmt.Sprintf("+%d", created))
}
if updated > 0 {
if len(changeStr) > 0 {
changeStr += " "
}
changeStr += yellow(fmt.Sprintf("~%d", updated))
}
if deleted > 0 {
if len(changeStr) > 0 {
changeStr += " "
}
changeStr += red(fmt.Sprintf("-%d", deleted))
}
if whitelisted {
if len(changeStr) > 0 {
changeStr += " "
}
changeStr += red("[whitelisted]")
}
if changeStr == "" {
changeStr = yellow("unchanged")
}
if res {
fmt.Printf("\t%s\t%s %s %s (%s)\n", presep, sep, emoji.GreenCircle, parser, changeStr)
if opts.Details {
fmt.Print(detailsDisplay)
}
} else if opts.ShowNotOkParsers {
fmt.Printf("\t%s\t%s %s %s\n", presep, sep, emoji.RedCircle, parser)
}
}
}
sep := "└"
if len(state[tstamp]["buckets"]) > 0 {
sep = "├"
}
//did the event enter the bucket pour phase ?
if _, ok := state[tstamp]["buckets"]["OK"]; ok {
fmt.Printf("\t%s-------- parser success %s\n", sep, emoji.GreenCircle)
} else if whitelistReason != "" {
fmt.Printf("\t%s-------- parser success, ignored by whitelist (%s) %s\n", sep, whitelistReason, emoji.GreenCircle)
} else {
fmt.Printf("\t%s-------- parser failure %s\n", sep, emoji.RedCircle)
}
//now print bucket info
if len(state[tstamp]["buckets"]) > 0 {
fmt.Printf("\t├ Scenarios\n")
}
bnames := make([]string, 0, len(state[tstamp]["buckets"]))
for k := range state[tstamp]["buckets"] {
//there is a trick : to know if an event successfully exit the parsers, we check if it reached the pour() phase
//we thus use a fake stage "buckets" and a fake parser "OK" to know if it entered
if k == "OK" {
continue
}
bnames = append(bnames, k)
}
sort.Strings(bnames)
for idx, bname := range bnames {
sep := "├"
if idx == len(bnames)-1 {
sep = "└"
}
fmt.Printf("\t\t%s %s %s\n", sep, emoji.GreenCircle, bname)
}
fmt.Println()
}
}

View file

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2"
"github.com/crowdsecurity/crowdsec/pkg/dumps"
"github.com/crowdsecurity/crowdsec/pkg/exprhelpers"
"github.com/crowdsecurity/crowdsec/pkg/types"
)
@ -24,11 +25,10 @@ type ScenarioAssert struct {
Fails []AssertFail
Success bool
TestData *BucketResults
PourData *BucketPourInfo
PourData *dumps.BucketPourInfo
}
type BucketResults []types.Event
type BucketPourInfo map[string][]types.Event
func NewScenarioAssert(file string) *ScenarioAssert {
ScenarioAssert := &ScenarioAssert{
@ -38,7 +38,7 @@ func NewScenarioAssert(file string) *ScenarioAssert {
Fails: make([]AssertFail, 0),
AutoGenAssert: false,
TestData: &BucketResults{},
PourData: &BucketPourInfo{},
PourData: &dumps.BucketPourInfo{},
}
return ScenarioAssert
@ -64,7 +64,7 @@ func (s *ScenarioAssert) LoadTest(filename string, bucketpour string) error {
s.TestData = bucketDump
if bucketpour != "" {
pourDump, err := LoadBucketPourDump(bucketpour)
pourDump, err := dumps.LoadBucketPourDump(bucketpour)
if err != nil {
return fmt.Errorf("loading bucket pour dump file '%s': %+v", filename, err)
}
@ -252,27 +252,6 @@ func (b BucketResults) Swap(i, j int) {
b[i], b[j] = b[j], b[i]
}
func LoadBucketPourDump(filepath string) (*BucketPourInfo, error) {
dumpData, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer dumpData.Close()
results, err := io.ReadAll(dumpData)
if err != nil {
return nil, err
}
var bucketDump BucketPourInfo
if err := yaml.Unmarshal(results, &bucketDump); err != nil {
return nil, err
}
return &bucketDump, nil
}
func LoadScenarioDump(filepath string) (*BucketResults, error) {
dumpData, err := os.Open(filepath)
if err != nil {

View file

@ -5,21 +5,25 @@ import (
"net"
"os"
"path/filepath"
"sort"
"time"
log "github.com/sirupsen/logrus"
)
func sortedMapKeys[V any](m map[string]V) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
func IsAlive(target string) (bool, error) {
start := time.Now()
for {
conn, err := net.Dial("tcp", target)
if err == nil {
log.Debugf("'%s' is up after %s", target, time.Since(start))
conn.Close()
return true, nil
}
time.Sleep(500 * time.Millisecond)
if time.Since(start) > 10*time.Second {
return false, fmt.Errorf("took more than 10s for %s to be available", target)
}
}
sort.Strings(keys)
return keys
}
func Copy(src string, dst string) error {
@ -110,19 +114,3 @@ func CopyDir(src string, dest string) error {
return nil
}
func IsAlive(target string) (bool, error) {
start := time.Now()
for {
conn, err := net.Dial("tcp", target)
if err == nil {
log.Debugf("'%s' is up after %s", target, time.Since(start))
conn.Close()
return true, nil
}
time.Sleep(500 * time.Millisecond)
if time.Since(start) > 10*time.Second {
return false, fmt.Errorf("took more than 10s for %s to be available", target)
}
}
}

View file

@ -18,6 +18,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/crowdsec/pkg/dumps"
"github.com/crowdsecurity/crowdsec/pkg/exprhelpers"
"github.com/crowdsecurity/crowdsec/pkg/types"
)
@ -229,14 +230,10 @@ func stageidx(stage string, stages []string) int {
return -1
}
type ParserResult struct {
Evt types.Event
Success bool
}
var ParseDump bool
var DumpFolder string
var StageParseCache map[string]map[string][]ParserResult
var StageParseCache dumps.ParserResults
var StageParseMutex sync.Mutex
func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error) {
@ -271,9 +268,9 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error)
if ParseDump {
if StageParseCache == nil {
StageParseMutex.Lock()
StageParseCache = make(map[string]map[string][]ParserResult)
StageParseCache["success"] = make(map[string][]ParserResult)
StageParseCache["success"][""] = make([]ParserResult, 0)
StageParseCache = make(dumps.ParserResults)
StageParseCache["success"] = make(map[string][]dumps.ParserResult)
StageParseCache["success"][""] = make([]dumps.ParserResult, 0)
StageParseMutex.Unlock()
}
}
@ -282,7 +279,7 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error)
if ParseDump {
StageParseMutex.Lock()
if _, ok := StageParseCache[stage]; !ok {
StageParseCache[stage] = make(map[string][]ParserResult)
StageParseCache[stage] = make(map[string][]dumps.ParserResult)
}
StageParseMutex.Unlock()
}
@ -322,13 +319,18 @@ func Parse(ctx UnixParserCtx, xp types.Event, nodes []Node) (types.Event, error)
}
clog.Tracef("node (%s) ret : %v", node.rn, ret)
if ParseDump {
parserIdxInStage := 0
StageParseMutex.Lock()
if len(StageParseCache[stage][node.Name]) == 0 {
StageParseCache[stage][node.Name] = make([]ParserResult, 0)
StageParseCache[stage][node.Name] = make([]dumps.ParserResult, 0)
parserIdxInStage = len(StageParseCache[stage])
} else {
parserIdxInStage = StageParseCache[stage][node.Name][0].Idx
}
StageParseMutex.Unlock()
evtcopy := deepcopy.Copy(event)
parserInfo := ParserResult{Evt: evtcopy.(types.Event), Success: ret}
parserInfo := dumps.ParserResult{Evt: evtcopy.(types.Event), Success: ret, Idx: parserIdxInStage}
StageParseMutex.Lock()
StageParseCache[stage][node.Name] = append(StageParseCache[stage][node.Name], parserInfo)
StageParseMutex.Unlock()

View file

@ -353,7 +353,7 @@ func TestUnitFound(t *testing.T) {
installed, err := env.UnitFound("crowdsec-setup-detect.service")
require.NoError(err)
require.Equal(true, installed)
require.True(installed)
}
// TODO apply rules to filter a list of Service structs
@ -566,8 +566,8 @@ func TestDetectForcedUnit(t *testing.T) {
func TestDetectForcedProcess(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping on windows")
// while looking for service wizard: rule 'ProcessRunning("foobar")': while looking up running processes: could not get Name: A device attached to the system is not functioning.
t.Skip("skipping on windows")
}
require := require.New(t)

View file

@ -73,7 +73,7 @@ func TestParseIPSources(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
ips := tt.evt.ParseIPSources()
assert.Equal(t, ips, tt.expected)
assert.Equal(t, tt.expected, ips)
})
}
}

View file

@ -46,7 +46,7 @@ func SetDefaultLoggerConfig(cfgMode string, cfgFolder string, cfgLevel log.Level
}
logLevel = cfgLevel
log.SetLevel(logLevel)
logFormatter = &log.TextFormatter{TimestampFormat: "2006-01-02 15:04:05", FullTimestamp: true, ForceColors: forceColors}
logFormatter = &log.TextFormatter{TimestampFormat: time.RFC3339, FullTimestamp: true, ForceColors: forceColors}
log.SetFormatter(logFormatter)
return nil
}

View file

@ -19,7 +19,35 @@ setup() {
#----------
@test "cscli capi status: fails without credentials" {
config_enable_capi
ONLINE_API_CREDENTIALS_YAML="$(config_get '.api.server.online_client.credentials_path')"
# bogus values, won't be used
echo '{"login":"login","password":"password","url":"url"}' > "${ONLINE_API_CREDENTIALS_YAML}"
config_set "$ONLINE_API_CREDENTIALS_YAML" 'del(.url)'
rune -1 cscli capi status
assert_stderr --partial "can't load CAPI credentials from '$ONLINE_API_CREDENTIALS_YAML' (missing url field)"
config_set "$ONLINE_API_CREDENTIALS_YAML" 'del(.password)'
rune -1 cscli capi status
assert_stderr --partial "can't load CAPI credentials from '$ONLINE_API_CREDENTIALS_YAML' (missing password field)"
config_set "$ONLINE_API_CREDENTIALS_YAML" 'del(.login)'
rune -1 cscli capi status
assert_stderr --partial "can't load CAPI credentials from '$ONLINE_API_CREDENTIALS_YAML' (missing login field)"
rm "${ONLINE_API_CREDENTIALS_YAML}"
rune -1 cscli capi status
assert_stderr --partial "failed to load Local API: loading online client credentials: open ${ONLINE_API_CREDENTIALS_YAML}: no such file or directory"
config_set 'del(.api.server.online_client)'
rune -1 cscli capi status
assert_stderr --partial "no configuration for Central API (CAPI) in '$CONFIG_YAML'"
}
@test "cscli capi status" {
./instance-data load
config_enable_capi
rune -0 cscli capi register --schmilblick githubciXXXXXXXXXXXXXXXXXXXXXXXX
rune -1 cscli capi status
@ -60,13 +88,6 @@ setup() {
assert_stderr --partial "You can successfully interact with Central API (CAPI)"
}
@test "cscli capi status: fails without credentials" {
ONLINE_API_CREDENTIALS_YAML="$(config_get '.api.server.online_client.credentials_path')"
rm "${ONLINE_API_CREDENTIALS_YAML}"
rune -1 cscli capi status
assert_stderr --partial "failed to load Local API: loading online client credentials: failed to read api server credentials configuration file '${ONLINE_API_CREDENTIALS_YAML}': open ${ONLINE_API_CREDENTIALS_YAML}: no such file or directory"
}
@test "capi register must be run from lapi" {
config_disable_lapi
rune -1 cscli capi register --schmilblick githubciXXXXXXXXXXXXXXXXXXXXXXXX

View file

@ -46,9 +46,21 @@ teardown() {
assert_stderr --partial "loading console context from $CONTEXT_YAML"
}
@test "no error if context file is missing but not explicitly set" {
config_set "del(.crowdsec_service.console_context_path)"
rune -0 rm -f "$CONTEXT_YAML"
rune -0 cscli lapi context status --error
refute_stderr
assert_output --partial "No context found on this agent."
rune -0 "$CROWDSEC" -t
refute_stderr --partial "no such file or directory"
}
@test "error if context file is explicitly set but does not exist" {
config_set ".crowdsec_service.console_context_path=strenv(CONTEXT_YAML)"
rune -0 rm -f "$CONTEXT_YAML"
rune -1 cscli lapi context status --error
assert_stderr --partial "context.yaml: no such file or directory"
rune -1 "$CROWDSEC" -t
assert_stderr --partial "while checking console_context_path: stat $CONTEXT_YAML: no such file or directory"
}

View file

@ -30,7 +30,7 @@ teardown() {
@test "capi_whitelists: file missing" {
rune -0 wait-for \
--err "capi whitelist file '$CAPI_WHITELISTS_YAML' does not exist" \
--err "while opening capi whitelist file: open $CAPI_WHITELISTS_YAML: no such file or directory" \
"${CROWDSEC}"
}

View file

@ -69,6 +69,16 @@ teardown() {
assert_output --partial 'crowdsecurity/iptables'
}
@test "cscli hub list (invalid index)" {
new_hub=$(jq <"$INDEX_PATH" '."appsec-rules"."crowdsecurity/vpatch-laravel-debug-mode".version="999"')
echo "$new_hub" >"$INDEX_PATH"
rune -0 cscli hub list --error
assert_stderr --partial "invalid hub item appsec-rules:crowdsecurity/vpatch-laravel-debug-mode: latest version missing from index"
rune -1 cscli appsec-rules install crowdsecurity/vpatch-laravel-debug-mode --force
assert_stderr --partial "error while installing 'crowdsecurity/vpatch-laravel-debug-mode': while downloading crowdsecurity/vpatch-laravel-debug-mode: latest hash missing from index"
}
@test "missing reference in hub index" {
new_hub=$(jq <"$INDEX_PATH" 'del(.parsers."crowdsecurity/smb-logs") | del (.scenarios."crowdsecurity/mysql-bf")')
echo "$new_hub" >"$INDEX_PATH"

View file

@ -102,7 +102,7 @@ log_info() {
log_fatal() {
msg=$1
date=$(date "+%Y-%m-%d %H:%M:%S")
echo -e "${RED}FATA${NC}[${date}] crowdsec_wizard: ${msg}" 1>&2
echo -e "${RED}FATA${NC}[${date}] crowdsec_wizard: ${msg}" 1>&2
exit 1
}
@ -129,16 +129,16 @@ log_dbg() {
detect_services () {
DETECTED_SERVICES=()
HMENU=()
#list systemd services
# list systemd services
SYSTEMD_SERVICES=`systemctl --state=enabled list-unit-files '*.service' | cut -d ' ' -f1`
#raw ps
# raw ps
PSAX=`ps ax -o comm=`
for SVC in ${SUPPORTED_SERVICES} ; do
log_dbg "Checking if service '${SVC}' is running (ps+systemd)"
for SRC in "${SYSTEMD_SERVICES}" "${PSAX}" ; do
echo ${SRC} | grep ${SVC} >/dev/null
if [ $? -eq 0 ]; then
#on centos, apache2 is named httpd
# on centos, apache2 is named httpd
if [[ ${SVC} == "httpd" ]] ; then
SVC="apache2";
fi
@ -152,12 +152,12 @@ detect_services () {
if [[ ${OSTYPE} == "linux-gnu" ]] || [[ ${OSTYPE} == "linux-gnueabihf" ]]; then
DETECTED_SERVICES+=("linux")
HMENU+=("linux" "on")
else
else
log_info "NOT A LINUX"
fi;
if [[ ${SILENT} == "false" ]]; then
#we put whiptail results in an array, notice the dark magic fd redirection
# we put whiptail results in an array, notice the dark magic fd redirection
DETECTED_SERVICES=($(whiptail --separate-output --noitem --ok-button Continue --title "Services to monitor" --checklist "Detected services, uncheck to ignore. Ignored services won't be monitored." 18 70 10 ${HMENU[@]} 3>&1 1>&2 2>&3))
if [ $? -eq 1 ]; then
log_err "user bailed out at services selection"
@ -189,28 +189,28 @@ log_locations[mysql]='/var/log/mysql/error.log'
log_locations[smb]='/var/log/samba*.log'
log_locations[linux]='/var/log/syslog,/var/log/kern.log,/var/log/messages'
#$1 is service name, such those in SUPPORTED_SERVICES
# $1 is service name, such those in SUPPORTED_SERVICES
find_logs_for() {
ret=""
x=${1}
#we have trailing and starting quotes because of whiptail
# we have trailing and starting quotes because of whiptail
SVC="${x%\"}"
SVC="${SVC#\"}"
DETECTED_LOGFILES=()
HMENU=()
#log_info "Searching logs for ${SVC} : ${log_locations[${SVC}]}"
# log_info "Searching logs for ${SVC} : ${log_locations[${SVC}]}"
#split the line into an array with ',' separator
# split the line into an array with ',' separator
OIFS=${IFS}
IFS=',' read -r -a a <<< "${log_locations[${SVC}]},"
IFS=${OIFS}
#readarray -td, a <<<"${log_locations[${SVC}]},"; unset 'a[-1]';
# readarray -td, a <<<"${log_locations[${SVC}]},"; unset 'a[-1]';
for poss_path in "${a[@]}"; do
#Split /var/log/nginx/*.log into '/var/log/nginx' and '*.log' so we can use find
# Split /var/log/nginx/*.log into '/var/log/nginx' and '*.log' so we can use find
path=${poss_path%/*}
fname=${poss_path##*/}
candidates=`find "${path}" -type f -mtime -5 -ctime -5 -name "$fname"`
#We have some candidates, add them
# We have some candidates, add them
for final_file in ${candidates} ; do
log_dbg "Found logs file for '${SVC}': ${final_file}"
DETECTED_LOGFILES+=(${final_file})
@ -249,12 +249,12 @@ install_collection() {
in_array $collection "${DETECTED_SERVICES[@]}"
if [[ $? == 0 ]]; then
HMENU+=("${collection}" "${description}" "ON")
#in case we're not in interactive mode, assume defaults
# in case we're not in interactive mode, assume defaults
COLLECTION_TO_INSTALL+=(${collection})
else
if [[ ${collection} == "linux" ]]; then
HMENU+=("${collection}" "${description}" "ON")
#in case we're not in interactive mode, assume defaults
# in case we're not in interactive mode, assume defaults
COLLECTION_TO_INSTALL+=(${collection})
else
HMENU+=("${collection}" "${description}" "OFF")
@ -272,10 +272,10 @@ install_collection() {
for collection in "${COLLECTION_TO_INSTALL[@]}"; do
log_info "Installing collection '${collection}'"
${CSCLI_BIN_INSTALLED} collections install "${collection}" > /dev/null 2>&1 || log_err "fail to install collection ${collection}"
${CSCLI_BIN_INSTALLED} collections install "${collection}" --error
done
${CSCLI_BIN_INSTALLED} parsers install "crowdsecurity/whitelists" > /dev/null 2>&1 || log_err "fail to install collection crowdsec/whitelists"
${CSCLI_BIN_INSTALLED} parsers install "crowdsecurity/whitelists" --error
if [[ ${SILENT} == "false" ]]; then
whiptail --msgbox "Out of safety, I installed a parser called 'crowdsecurity/whitelists'. This one will prevent private IP addresses from being banned, feel free to remove it any time." 20 50
fi
@ -285,14 +285,14 @@ install_collection() {
fi
}
#$1 is the service name, $... is the list of candidate logs (from find_logs_for)
# $1 is the service name, $... is the list of candidate logs (from find_logs_for)
genyamllog() {
local service="${1}"
shift
local files=("${@}")
echo "#Generated acquisition file - wizard.sh (service: ${service}) / files : ${files[@]}" >> ${TMP_ACQUIS_FILE}
echo "filenames:" >> ${TMP_ACQUIS_FILE}
for fd in ${files[@]}; do
echo " - ${fd}" >> ${TMP_ACQUIS_FILE}
@ -306,9 +306,9 @@ genyamllog() {
genyamljournal() {
local service="${1}"
shift
echo "#Generated acquisition file - wizard.sh (service: ${service}) / files : ${files[@]}" >> ${TMP_ACQUIS_FILE}
echo "journalctl_filter:" >> ${TMP_ACQUIS_FILE}
echo " - _SYSTEMD_UNIT="${service}".service" >> ${TMP_ACQUIS_FILE}
echo "labels:" >> ${TMP_ACQUIS_FILE}
@ -318,7 +318,7 @@ genyamljournal() {
}
genacquisition() {
if skip_tmp_acquis; then
if skip_tmp_acquis; then
TMP_ACQUIS_FILE="${ACQUIS_TARGET}"
ACQUIS_FILE_MSG="acquisition file generated to: ${TMP_ACQUIS_FILE}"
else
@ -336,7 +336,7 @@ genacquisition() {
log_info "using journald for '${PSVG}'"
genyamljournal ${PSVG}
fi;
done
done
}
detect_cs_install () {
@ -371,7 +371,7 @@ check_cs_version () {
fi
elif [[ $NEW_MINOR_VERSION -gt $CURRENT_MINOR_VERSION ]] ; then
log_warn "new version ($NEW_CS_VERSION) is a minor upgrade !"
if [[ $ACTION != "upgrade" ]] ; then
if [[ $ACTION != "upgrade" ]] ; then
if [[ ${FORCE_MODE} == "false" ]]; then
echo ""
echo "We recommend to upgrade with : sudo ./wizard.sh --upgrade "
@ -383,7 +383,7 @@ check_cs_version () {
fi
elif [[ $NEW_PATCH_VERSION -gt $CURRENT_PATCH_VERSION ]] ; then
log_warn "new version ($NEW_CS_VERSION) is a patch !"
if [[ $ACTION != "binupgrade" ]] ; then
if [[ $ACTION != "binupgrade" ]] ; then
if [[ ${FORCE_MODE} == "false" ]]; then
echo ""
echo "We recommend to upgrade binaries only : sudo ./wizard.sh --binupgrade "
@ -406,7 +406,7 @@ check_cs_version () {
fi
}
#install crowdsec and cscli
# install crowdsec and cscli
install_crowdsec() {
mkdir -p "${CROWDSEC_DATA_DIR}"
(cd config && find patterns -type f -exec install -Dm 644 "{}" "${CROWDSEC_CONFIG_PATH}/{}" \; && cd ../) || exit
@ -418,7 +418,7 @@ install_crowdsec() {
mkdir -p "${CROWDSEC_CONFIG_PATH}/appsec-rules" || exit
mkdir -p "${CROWDSEC_CONSOLE_DIR}" || exit
#tmp
# tmp
mkdir -p /tmp/data
mkdir -p /etc/crowdsec/hub/
install -v -m 600 -D "./config/${CLIENT_SECRETS}" "${CROWDSEC_CONFIG_PATH}" 1> /dev/null || exit
@ -490,7 +490,7 @@ install_bins() {
install -v -m 755 -D "${CSCLI_BIN}" "${CSCLI_BIN_INSTALLED}" 1> /dev/null || exit
which systemctl && systemctl is-active --quiet crowdsec
if [ $? -eq 0 ]; then
systemctl stop crowdsec
systemctl stop crowdsec
fi
install_plugins
symlink_bins
@ -508,7 +508,7 @@ symlink_bins() {
delete_bins() {
log_info "Removing crowdsec binaries"
rm -f ${CROWDSEC_BIN_INSTALLED}
rm -f ${CSCLI_BIN_INSTALLED}
rm -f ${CSCLI_BIN_INSTALLED}
}
delete_plugins() {
@ -535,7 +535,7 @@ install_plugins(){
}
check_running_bouncers() {
#when uninstalling, check if user still has bouncers
# when uninstalling, check if user still has bouncers
BOUNCERS_COUNT=$(${CSCLI_BIN} bouncers list -o=raw | tail -n +2 | wc -l)
if [[ ${BOUNCERS_COUNT} -gt 0 ]] ; then
if [[ ${FORCE_MODE} == "false" ]]; then
@ -646,7 +646,7 @@ main() {
then
return
fi
if [[ "$1" == "uninstall" ]];
then
if ! [ $(id -u) = 0 ]; then
@ -685,11 +685,11 @@ main() {
log_info "installing crowdsec"
install_crowdsec
log_dbg "configuring ${CSCLI_BIN_INSTALLED}"
${CSCLI_BIN_INSTALLED} hub update > /dev/null 2>&1 || (log_err "fail to update crowdsec hub. exiting" && exit 1)
${CSCLI_BIN_INSTALLED} hub update --error || (log_err "fail to update crowdsec hub. exiting" && exit 1)
# detect running services
detect_services
if ! [ ${#DETECTED_SERVICES[@]} -gt 0 ] ; then
if ! [ ${#DETECTED_SERVICES[@]} -gt 0 ] ; then
log_err "No detected or selected services, stopping."
exit 1
fi;
@ -711,11 +711,11 @@ main() {
# api register
${CSCLI_BIN_INSTALLED} machines add --force "$(cat /etc/machine-id)" -a -f "${CROWDSEC_CONFIG_PATH}/${CLIENT_SECRETS}" || log_fatal "unable to add machine to the local API"
log_dbg "Crowdsec LAPI registered"
log_dbg "Crowdsec LAPI registered"
${CSCLI_BIN_INSTALLED} capi register || log_fatal "unable to register to the Central API"
log_dbg "Crowdsec CAPI registered"
log_dbg "Crowdsec CAPI registered"
systemctl enable -q crowdsec >/dev/null || log_fatal "unable to enable crowdsec"
systemctl start crowdsec >/dev/null || log_fatal "unable to start crowdsec"
log_info "enabling and starting crowdsec daemon"
@ -729,7 +729,7 @@ main() {
rm -f "${TMP_ACQUIS_FILE}"
fi
detect_services
if [[ ${DETECTED_SERVICES} == "" ]] ; then
if [[ ${DETECTED_SERVICES} == "" ]] ; then
log_err "No detected or selected services, stopping."
exit
fi;
@ -757,7 +757,7 @@ usage() {
echo " ./wizard.sh --docker-mode Will install crowdsec without systemd and generate random machine-id"
echo " ./wizard.sh -n|--noop Do nothing"
exit 0
exit 0
}
if [[ $# -eq 0 ]]; then
@ -770,15 +770,15 @@ do
case ${key} in
--uninstall)
ACTION="uninstall"
shift #past argument
shift # past argument
;;
--binupgrade)
ACTION="binupgrade"
shift #past argument
shift # past argument
;;
--upgrade)
ACTION="upgrade"
shift #past argument
shift # past argument
;;
-i|--install)
ACTION="install"
@ -813,11 +813,11 @@ do
-f|--force)
FORCE_MODE="true"
shift
;;
;;
-v|--verbose)
DEBUG_MODE="true"
shift
;;
;;
-h|--help)
usage
exit 0