diff --git a/cmd/crowdsec-cli/explain.go b/cmd/crowdsec-cli/explain.go index d3b92c2b5..3e7f48fa0 100644 --- a/cmd/crowdsec-cli/explain.go +++ b/cmd/crowdsec-cli/explain.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "errors" "fmt" "io" "os" @@ -196,7 +197,7 @@ func (cli cliExplain) run(cmd *cobra.Command, args []string) error { errCount := 0 for { input, err := reader.ReadBytes('\n') - if err != nil && err == io.EOF { + if err != nil && errors.Is(err, io.EOF) { break } if len(input) > 1 { diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 0f3b9c4a0..1fd65dc38 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -1,6 +1,7 @@ package cache import ( + "errors" "time" "github.com/bluele/gcache" @@ -104,7 +105,7 @@ func GetKey(cacheName string, key string) (string, error) { if name == cacheName { if value, err := Caches[i].Get(key); err != nil { //do not warn or log if key not found - if err == gcache.KeyNotFoundError { + if errors.Is(err, gcache.KeyNotFoundError) { return "", nil } CacheConfig[i].Logger.Warningf("While getting key %s in cache %s: %s", key, cacheName, err) diff --git a/pkg/cticlient/client.go b/pkg/cticlient/client.go index 16876026a..4df4d65a6 100644 --- a/pkg/cticlient/client.go +++ b/pkg/cticlient/client.go @@ -71,7 +71,7 @@ func (c *CrowdsecCTIClient) doRequest(method string, endpoint string, params map func (c *CrowdsecCTIClient) GetIPInfo(ip string) (*SmokeItem, error) { body, err := c.doRequest(http.MethodGet, smokeEndpoint+"/"+ip, nil) if err != nil { - if err == ErrNotFound { + if errors.Is(err, ErrNotFound) { return &SmokeItem{}, nil } return nil, err diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index 0ae63f374..83af76464 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -742,7 +742,7 @@ func (c *Client) CreateAlert(machineID string, alertList []*models.Alert) ([]str if machineID != "" { owner, err = c.QueryMachineByID(machineID) if err != nil { - if errors.Cause(err) != UserNotExists { + if !errors.Is(err, UserNotExists) { return nil, fmt.Errorf("machine '%s': %w", machineID, err) } diff --git a/pkg/exprhelpers/crowdsec_cti.go b/pkg/exprhelpers/crowdsec_cti.go index dc67815d8..59a239722 100644 --- a/pkg/exprhelpers/crowdsec_cti.go +++ b/pkg/exprhelpers/crowdsec_cti.go @@ -104,12 +104,12 @@ func CrowdsecCTI(params ...any) (any, error) { ctiResp, err := ctiClient.GetIPInfo(ip) ctiClient.Logger.Debugf("request for %s took %v", ip, time.Since(before)) if err != nil { - switch err { - case cticlient.ErrUnauthorized: + switch { + case errors.Is(err, cticlient.ErrUnauthorized): CTIApiEnabled = false ctiClient.Logger.Errorf("Invalid API key provided, disabling CTI API") return &cticlient.SmokeItem{}, cticlient.ErrUnauthorized - case cticlient.ErrLimit: + case errors.Is(err, cticlient.ErrLimit): CTIBackOffUntil = time.Now().Add(CTIBackOffDuration) ctiClient.Logger.Errorf("CTI API is throttled, will try again in %s", CTIBackOffDuration) return &cticlient.SmokeItem{}, cticlient.ErrLimit diff --git a/pkg/exprhelpers/crowdsec_cti_test.go b/pkg/exprhelpers/crowdsec_cti_test.go index 84cd3347b..80ccadba4 100644 --- a/pkg/exprhelpers/crowdsec_cti_test.go +++ b/pkg/exprhelpers/crowdsec_cti_test.go @@ -3,6 +3,7 @@ package exprhelpers import ( "bytes" "encoding/json" + "errors" "io" "net/http" "strings" @@ -108,7 +109,7 @@ func smokeHandler(req *http.Request) *http.Response { func TestNillClient(t *testing.T) { defer ShutdownCrowdsecCTI() - if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); err != cticlient.ErrDisabled { + if err := InitCrowdsecCTI(ptr.Of(""), nil, nil, nil); !errors.Is(err, cticlient.ErrDisabled) { t.Fatalf("failed to init CTI : %s", err) } item, err := CrowdsecCTI("1.2.3.4") diff --git a/pkg/longpollclient/client.go b/pkg/longpollclient/client.go index a35afa813..e93870a28 100644 --- a/pkg/longpollclient/client.go +++ b/pkg/longpollclient/client.go @@ -2,6 +2,7 @@ package longpollclient import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -112,7 +113,7 @@ func (c *LongPollClient) poll() error { var pollResp pollResponse err = decoder.Decode(&pollResp) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { logger.Debugf("server closed connection") return nil } @@ -158,7 +159,7 @@ func (c *LongPollClient) pollEvents() error { err := c.poll() if err != nil { c.logger.Errorf("failed to poll: %s", err) - if err == errUnauthorized { + if errors.Is(err, errUnauthorized) { c.t.Kill(err) close(c.c) return err @@ -198,7 +199,7 @@ func (c *LongPollClient) PullOnce(since time.Time) ([]Event, error) { var pollResp pollResponse err = decoder.Decode(&pollResp) if err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { c.logger.Debugf("server closed connection") break } diff --git a/pkg/parser/enrich_date_test.go b/pkg/parser/enrich_date_test.go index b794676cc..084ded525 100644 --- a/pkg/parser/enrich_date_test.go +++ b/pkg/parser/enrich_date_test.go @@ -4,34 +4,33 @@ import ( "testing" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" - "github.com/crowdsecurity/go-cs-lib/ptr" + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/types" ) func TestDateParse(t *testing.T) { tests := []struct { - name string - evt types.Event - expected_err *error - expected_strTime *string + name string + evt types.Event + expectedErr string + expected string }{ { name: "RFC3339", evt: types.Event{ StrTime: "2019-10-12T07:20:50.52Z", }, - expected_err: nil, - expected_strTime: ptr.Of("2019-10-12T07:20:50.52Z"), + expected: "2019-10-12T07:20:50.52Z", }, { name: "02/Jan/2006:15:04:05 -0700", evt: types.Event{ StrTime: "02/Jan/2006:15:04:05 -0700", }, - expected_err: nil, - expected_strTime: ptr.Of("2006-01-02T15:04:05-07:00"), + expected: "2006-01-02T15:04:05-07:00", }, { name: "Dec 17 08:17:43", @@ -39,8 +38,7 @@ func TestDateParse(t *testing.T) { StrTime: "2011 X 17 zz 08X17X43 oneone Dec", StrTimeFormat: "2006 X 2 zz 15X04X05 oneone Jan", }, - expected_err: nil, - expected_strTime: ptr.Of("2011-12-17T08:17:43Z"), + expected: "2011-12-17T08:17:43Z", }, } @@ -51,19 +49,11 @@ func TestDateParse(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { strTime, err := ParseDate(tt.evt.StrTime, &tt.evt, nil, logger) - if tt.expected_err != nil { - if err != *tt.expected_err { - t.Errorf("%s: expected error %v, got %v", tt.name, tt.expected_err, err) - } - } else if err != nil { - t.Errorf("%s: expected no error, got %v", tt.name, err) - } - if err != nil { + cstest.RequireErrorContains(t, err, tt.expectedErr) + if tt.expectedErr != "" { return } - if tt.expected_strTime != nil && strTime["MarshaledTime"] != *tt.expected_strTime { - t.Errorf("expected strTime %s, got %s", *tt.expected_strTime, strTime["MarshaledTime"]) - } + assert.Equal(t, tt.expected, strTime["MarshaledTime"]) }) } }