package s3acquisition import ( "context" "fmt" "strings" "sync/atomic" "testing" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3iface" "github.com/aws/aws-sdk-go/service/sqs" "github.com/aws/aws-sdk-go/service/sqs/sqsiface" "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "gopkg.in/tomb.v2" ) func TestBadConfiguration(t *testing.T) { tests := []struct { name string config string expectedErr string }{ { name: "no bucket", config: ` source: s3 `, expectedErr: "bucket_name is required", }, { name: "invalid polling method", config: ` source: s3 bucket_name: foobar polling_method: foobar `, expectedErr: "invalid polling method foobar", }, { name: "no sqs name", config: ` source: s3 bucket_name: foobar polling_method: sqs `, expectedErr: "sqs_name is required when using sqs polling method", }, { name: "both bucket and sqs", config: ` source: s3 bucket_name: foobar polling_method: sqs sqs_name: foobar `, expectedErr: "bucket_name and sqs_name are mutually exclusive", }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := S3Source{} err := f.Configure([]byte(test.config), nil) if err == nil { t.Fatalf("expected error, got none") } if err.Error() != test.expectedErr { t.Fatalf("expected error %s, got %s", test.expectedErr, err.Error()) } }) } } func TestGoodConfiguration(t *testing.T) { tests := []struct { name string config string }{ { name: "basic", config: ` source: s3 bucket_name: foobar `, }, { name: "polling method", config: ` source: s3 polling_method: sqs sqs_name: foobar `, }, { name: "list method", config: ` source: s3 bucket_name: foobar polling_method: list `, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := S3Source{} logger := log.NewEntry(log.New()) err := f.Configure([]byte(test.config), logger) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } }) } } type mockS3Client struct { s3iface.S3API } // We add one hour to trick the listing goroutine into thinking the files are new var mockListOutput map[string][]*s3.Object = map[string][]*s3.Object{ "bucket_no_prefix": { { Key: aws.String("foo.log"), LastModified: aws.Time(time.Now().Add(time.Hour)), }, }, "bucket_with_prefix": { { Key: aws.String("prefix/foo.log"), LastModified: aws.Time(time.Now().Add(time.Hour)), }, { Key: aws.String("prefix/bar.log"), LastModified: aws.Time(time.Now().Add(time.Hour)), }, }, } func (m mockS3Client) ListObjectsV2WithContext(ctx context.Context, input *s3.ListObjectsV2Input, options ...request.Option) (*s3.ListObjectsV2Output, error) { log.Infof("returning mock list output for %s, %v", *input.Bucket, mockListOutput[*input.Bucket]) return &s3.ListObjectsV2Output{ Contents: mockListOutput[*input.Bucket], }, nil } func (m mockS3Client) GetObjectWithContext(ctx context.Context, input *s3.GetObjectInput, options ...request.Option) (*s3.GetObjectOutput, error) { r := strings.NewReader("foo\nbar") return &s3.GetObjectOutput{ Body: aws.ReadSeekCloser(r), }, nil } type mockSQSClient struct { sqsiface.SQSAPI counter *int32 } func (msqs mockSQSClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) { if atomic.LoadInt32(msqs.counter) == 1 { return &sqs.ReceiveMessageOutput{}, nil } atomic.AddInt32(msqs.counter, 1) return &sqs.ReceiveMessageOutput{ Messages: []*sqs.Message{ { Body: aws.String(` {"version":"0","id":"af1ce7ea-bdb4-5bb7-3af2-c6cb32f9aac9","detail-type":"Object Created","source":"aws.s3","account":"1234","time":"2023-03-17T07:45:04Z","region":"eu-west-1","resources":["arn:aws:s3:::my_bucket"],"detail":{"version":"0","bucket":{"name":"my_bucket"},"object":{"key":"foo.log","size":663,"etag":"f2d5268a0776d6cdd6e14fcfba96d1cd","sequencer":"0064141A8022966874"},"request-id":"MBWX2P6FWA3S1YH5","requester":"156460612806","source-ip-address":"42.42.42.42","reason":"PutObject"}}`), }, }, }, nil } func (msqs mockSQSClient) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) { return &sqs.DeleteMessageOutput{}, nil } type mockSQSClientNotif struct { sqsiface.SQSAPI counter *int32 } func (msqs mockSQSClientNotif) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) { if atomic.LoadInt32(msqs.counter) == 1 { return &sqs.ReceiveMessageOutput{}, nil } atomic.AddInt32(msqs.counter, 1) return &sqs.ReceiveMessageOutput{ Messages: []*sqs.Message{ { Body: aws.String(` {"Records":[{"eventVersion":"2.1","eventSource":"aws:s3","awsRegion":"eu-west-1","eventTime":"2023-03-20T19:30:02.536Z","eventName":"ObjectCreated:Put","userIdentity":{"principalId":"AWS:XXXXX"},"requestParameters":{"sourceIPAddress":"42.42.42.42"},"responseElements":{"x-amz-request-id":"FM0TAV2WE5AXXW42","x-amz-id-2":"LCfQt1aSBtD1G5wdXjB5ANdPxLEXJxA89Ev+/rRAsCGFNJGI/1+HMlKI59S92lqvzfViWh7B74leGKWB8/nNbsbKbK7WXKz2"},"s3":{"s3SchemaVersion":"1.0","configurationId":"test-acquis","bucket":{"name":"my_bucket","ownerIdentity":{"principalId":"A1F2PSER1FB8MY"},"arn":"arn:aws:s3:::my_bucket"},"object":{"key":"foo.log","size":3097,"eTag":"ab6889744611c77991cbc6ca12d1ddc7","sequencer":"006418B43A76BC0257"}}}]}`), }, }, }, nil } func (msqs mockSQSClientNotif) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) { return &sqs.DeleteMessageOutput{}, nil } func TestDSNAcquis(t *testing.T) { tests := []struct { name string dsn string expectedBucketName string expectedPrefix string expectedCount int }{ { name: "basic", dsn: "s3://bucket_no_prefix/foo.log", expectedBucketName: "bucket_no_prefix", expectedPrefix: "", expectedCount: 2, }, { name: "with prefix", dsn: "s3://bucket_with_prefix/prefix/", expectedBucketName: "bucket_with_prefix", expectedPrefix: "prefix/", expectedCount: 4, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { linesRead := 0 f := S3Source{} logger := log.NewEntry(log.New()) err := f.ConfigureByDSN(test.dsn, map[string]string{"foo": "bar"}, logger, "") if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } assert.Equal(t, test.expectedBucketName, f.Config.BucketName) assert.Equal(t, test.expectedPrefix, f.Config.Prefix) out := make(chan types.Event) done := make(chan bool) go func() { for { select { case s := <-out: fmt.Printf("got line %s\n", s.Line.Raw) linesRead++ case <-done: return } } }() f.s3Client = mockS3Client{} tmb := tomb.Tomb{} err = f.OneShotAcquisition(out, &tmb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } time.Sleep(2 * time.Second) done <- true assert.Equal(t, test.expectedCount, linesRead) }) } } func TestListPolling(t *testing.T) { tests := []struct { name string config string expectedCount int }{ { name: "basic", config: ` source: s3 bucket_name: bucket_no_prefix polling_method: list polling_interval: 1 `, expectedCount: 2, }, { name: "with prefix", config: ` source: s3 bucket_name: bucket_with_prefix polling_method: list polling_interval: 1 prefix: foo/ `, expectedCount: 4, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { linesRead := 0 f := S3Source{} logger := log.NewEntry(log.New()) logger.Logger.SetLevel(log.TraceLevel) err := f.Configure([]byte(test.config), logger) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } if f.Config.PollingMethod != PollMethodList { t.Fatalf("expected list polling, got %s", f.Config.PollingMethod) } f.s3Client = mockS3Client{} out := make(chan types.Event) tb := tomb.Tomb{} go func() { for { select { case s := <-out: fmt.Printf("got line %s\n", s.Line.Raw) linesRead++ case <-tb.Dying(): return } } }() err = f.StreamingAcquisition(out, &tb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } time.Sleep(2 * time.Second) tb.Kill(nil) err = tb.Wait() if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } assert.Equal(t, test.expectedCount, linesRead) }) } } func TestSQSPoll(t *testing.T) { tests := []struct { name string config string notifType string expectedCount int }{ { name: "eventbridge", config: ` source: s3 polling_method: sqs sqs_name: test `, expectedCount: 2, notifType: "eventbridge", }, { name: "notification", config: ` source: s3 polling_method: sqs sqs_name: test `, expectedCount: 2, notifType: "notification", }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { linesRead := 0 f := S3Source{} logger := log.NewEntry(log.New()) err := f.Configure([]byte(test.config), logger) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } if f.Config.PollingMethod != PollMethodSQS { t.Fatalf("expected sqs polling, got %s", f.Config.PollingMethod) } counter := int32(0) f.s3Client = mockS3Client{} if test.notifType == "eventbridge" { f.sqsClient = mockSQSClient{counter: &counter} } else { f.sqsClient = mockSQSClientNotif{counter: &counter} } out := make(chan types.Event) tb := tomb.Tomb{} go func() { for { select { case s := <-out: fmt.Printf("got line %s\n", s.Line.Raw) linesRead++ case <-tb.Dying(): return } } }() err = f.StreamingAcquisition(out, &tb) if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } time.Sleep(2 * time.Second) tb.Kill(nil) err = tb.Wait() if err != nil { t.Fatalf("unexpected error: %s", err.Error()) } assert.Equal(t, test.expectedCount, linesRead) }) } }