From dc38e5ac00a66160f56952763620b06f8887176a Mon Sep 17 00:00:00 2001 From: blotus Date: Tue, 21 Mar 2023 13:54:52 +0100 Subject: [PATCH] S3 acquisition datasource (#2130) --- go.mod | 1 + go.sum | 2 + pkg/acquisition/acquisition.go | 2 + pkg/acquisition/modules/s3/s3.go | 647 ++++++++++++++++++++++++++ pkg/acquisition/modules/s3/s3_test.go | 429 +++++++++++++++++ 5 files changed, 1081 insertions(+) create mode 100644 pkg/acquisition/modules/s3/s3.go create mode 100644 pkg/acquisition/modules/s3/s3_test.go diff --git a/go.mod b/go.mod index 00a6098cc..8b94ef948 100644 --- a/go.mod +++ b/go.mod @@ -98,6 +98,7 @@ require ( github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26 // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef // indirect + github.com/aws/aws-lambda-go v1.38.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/containerd/containerd v1.6.18 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect diff --git a/go.sum b/go.sum index baaa391fb..3b2bc3b13 100644 --- a/go.sum +++ b/go.sum @@ -107,6 +107,8 @@ github.com/asaskevich/govalidator v0.0.0-20200108200545-475eaeb16496/go.mod h1:o github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535/go.mod h1:oGkLhpf+kjZl6xBf758TQhh5XrAeiJv/7FRz/2spLIg= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef h1:46PFijGLmAjMPwCCCo7Jf0W6f9slllCkkv7vyc1yOSg= github.com/asaskevich/govalidator v0.0.0-20200907205600-7a23bdc65eef/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= +github.com/aws/aws-lambda-go v1.38.0 h1:4CUdxGzvuQp0o8Zh7KtupB9XvCiiY8yKqJtzco+gsDw= +github.com/aws/aws-lambda-go v1.38.0/go.mod h1:jwFe2KmMsHmffA1X2R09hH6lFzJQxzI8qK17ewzbQMM= github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= github.com/aws/aws-sdk-go v1.42.25 h1:BbdvHAi+t9LRiaYUyd53noq9jcaAcfzOhSVbKfr6Avs= github.com/aws/aws-sdk-go v1.42.25/go.mod h1:gyRszuZ/icHmHAVE4gc/r+cfCmhA1AD+vqfWbgI+eHs= diff --git a/pkg/acquisition/acquisition.go b/pkg/acquisition/acquisition.go index 113088e7c..5ffa2d441 100644 --- a/pkg/acquisition/acquisition.go +++ b/pkg/acquisition/acquisition.go @@ -20,6 +20,7 @@ import ( kafkaacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kafka" kinesisacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kinesis" k8sauditacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kubernetesaudit" + s3acquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/s3" syslogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog" wineventlogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/wineventlog" @@ -52,6 +53,7 @@ var AcquisitionSources = map[string]func() DataSource{ "wineventlog": func() DataSource { return &wineventlogacquisition.WinEventLogSource{} }, "kafka": func() DataSource { return &kafkaacquisition.KafkaSource{} }, "k8s_audit": func() DataSource { return &k8sauditacquisition.KubernetesAuditSource{} }, + "s3": func() DataSource { return &s3acquisition.S3Source{} }, } func GetDataSourceIface(dataSourceType string) DataSource { diff --git a/pkg/acquisition/modules/s3/s3.go b/pkg/acquisition/modules/s3/s3.go new file mode 100644 index 000000000..d30bc84b6 --- /dev/null +++ b/pkg/acquisition/modules/s3/s3.go @@ -0,0 +1,647 @@ +package s3acquisition + +import ( + "bufio" + "compress/gzip" + "context" + "encoding/json" + "fmt" + "net/url" + "sort" + 
"strings" + "time" + + "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" + "gopkg.in/tomb.v2" + "gopkg.in/yaml.v2" +) + +type S3Configuration struct { + configuration.DataSourceCommonCfg `yaml:",inline"` + AwsProfile *string `yaml:"aws_profile"` + AwsRegion string `yaml:"aws_region"` + AwsEndpoint string `yaml:"aws_endpoint"` + BucketName string `yaml:"bucket_name"` + Prefix string `yaml:"prefix"` + Key string `yaml:"-"` //Only for DSN acquisition + PollingMethod string `yaml:"polling_method"` + PollingInterval int `yaml:"polling_interval"` + SQSName string `yaml:"sqs_name"` + SQSFormat string `yaml:"sqs_format"` +} + +type S3Source struct { + Config S3Configuration + logger *log.Entry + s3Client s3iface.S3API + sqsClient sqsiface.SQSAPI + readerChan chan S3Object + t *tomb.Tomb + out chan types.Event + ctx aws.Context + cancel context.CancelFunc +} + +type S3Object struct { + Key string + Bucket string +} + +// For some reason, the aws sdk doesn't have a struct for this +// The one aws-lamdbda-go/events is only intended when using S3 Notification without event bridge +type S3Event struct { + Version string `json:"version"` + Id string `json:"id"` + DetailType string `json:"detail-type"` + Source string `json:"source"` + Account string `json:"account"` + Time string `json:"time"` + Region string `json:"region"` + Resources []string `json:"resources"` + Detail struct { + Version string `json:"version"` + RequestId string `json:"request-id"` + Requester string `json:"requester"` + Reason string `json:"reason"` + SourceIpAddress string `json:"source-ip-address"` + Bucket struct { + Name string `json:"name"` + } `json:"bucket"` + Object struct { + Key string `json:"key"` + Size int `json:"size"` + Etag string `json:"etag"` + Sequencer string `json:"sequencer"` + } `json:"object"` + } `json:"detail"` +} + +const PollMethodList = "list" +const PollMethodSQS = "sqs" +const SQSFormatEventBridge = "eventbridge" +const SQSFormatS3Notification = "s3notification" + +var linesRead = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_s3_hits_total", + Help: "Number of events read per bucket.", + }, + []string{"bucket"}, +) + +var objectsRead = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_s3_objects_total", + Help: "Number of objects read per bucket.", + }, + []string{"bucket"}, +) + +var sqsMessagesReceived = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_s3_sqs_messages_total", + Help: "Number of SQS messages received per queue.", + }, + []string{"queue"}, +) + +func (s *S3Source) newS3Client() error { + options := session.Options{ + SharedConfigState: session.SharedConfigEnable, + } + if s.Config.AwsProfile != nil { + options.Profile = *s.Config.AwsProfile + } + + sess, err := session.NewSessionWithOptions(options) + + if err != nil { + return fmt.Errorf("failed to create aws session: %w", err) + } + + config := aws.NewConfig() + if s.Config.AwsRegion != "" { + config = config.WithRegion(s.Config.AwsRegion) + } + if s.Config.AwsEndpoint != "" { + config = 
config.WithEndpoint(s.Config.AwsEndpoint) + } + s.s3Client = s3.New(sess, config) + if s.s3Client == nil { + return fmt.Errorf("failed to create S3 client") + } + return nil +} + +func (s *S3Source) newSQSClient() error { + var sess *session.Session + + if s.Config.AwsProfile != nil { + sess = session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + Profile: *s.Config.AwsProfile, + })) + } else { + sess = session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + } + + if sess == nil { + return fmt.Errorf("failed to create aws session") + } + config := aws.NewConfig() + if s.Config.AwsRegion != "" { + config = config.WithRegion(s.Config.AwsRegion) + } + if s.Config.AwsEndpoint != "" { + config = config.WithEndpoint(s.Config.AwsEndpoint) + } + s.sqsClient = sqs.New(sess, config) + if s.sqsClient == nil { + return fmt.Errorf("failed to create SQS client") + } + return nil +} + +func (s *S3Source) readManager() { + logger := s.logger.WithField("method", "readManager") + for { + select { + case <-s.t.Dying(): + logger.Infof("Shutting down S3 read manager") + s.cancel() + return + case s3Object := <-s.readerChan: + logger.Debugf("Reading file %s/%s", s3Object.Bucket, s3Object.Key) + err := s.readFile(s3Object.Bucket, s3Object.Key) + if err != nil { + logger.Errorf("Error while reading file: %s", err) + } + } + } +} + +func (s *S3Source) getBucketContent() ([]*s3.Object, error) { + logger := s.logger.WithField("method", "getBucketContent") + logger.Debugf("Getting bucket content for %s", s.Config.BucketName) + bucketObjects := make([]*s3.Object, 0) + var continuationToken *string = nil + for { + out, err := s.s3Client.ListObjectsV2WithContext(s.ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(s.Config.BucketName), + Prefix: aws.String(s.Config.Prefix), + ContinuationToken: continuationToken, + }) + if err != nil { + logger.Errorf("Error while listing bucket content: %s", err) + return nil, err + } + bucketObjects = append(bucketObjects, out.Contents...) 
+ if out.NextContinuationToken == nil { + break + } + continuationToken = out.NextContinuationToken + } + sort.Slice(bucketObjects, func(i, j int) bool { + return bucketObjects[i].LastModified.Before(*bucketObjects[j].LastModified) + }) + return bucketObjects, nil +} + +func (s *S3Source) listPoll() error { + logger := s.logger.WithField("method", "listPoll") + ticker := time.NewTicker(time.Duration(s.Config.PollingInterval) * time.Second) + lastObjectDate := time.Now() + defer ticker.Stop() + + for { + select { + case <-s.t.Dying(): + logger.Infof("Shutting down list poller") + s.cancel() + return nil + case <-ticker.C: + newObject := false + bucketObjects, err := s.getBucketContent() + if err != nil { + logger.Errorf("Error while getting bucket content: %s", err) + continue + } + if bucketObjects == nil { + continue + } + for i := len(bucketObjects) - 1; i >= 0; i-- { + if bucketObjects[i].LastModified.After(lastObjectDate) { + newObject = true + logger.Debugf("Found new object %s", *bucketObjects[i].Key) + s.readerChan <- S3Object{ + Bucket: s.Config.BucketName, + Key: *bucketObjects[i].Key, + } + } else { + break + } + } + if newObject { + lastObjectDate = *bucketObjects[len(bucketObjects)-1].LastModified + } + } + } +} + +func extractBucketAndPrefixFromEventBridge(message *string) (string, string, error) { + eventBody := S3Event{} + err := json.Unmarshal([]byte(*message), &eventBody) + if err != nil { + return "", "", err + } + if eventBody.Detail.Bucket.Name != "" { + return eventBody.Detail.Bucket.Name, eventBody.Detail.Object.Key, nil + } + return "", "", fmt.Errorf("invalid event body for event bridge format") +} + +func extractBucketAndPrefixFromS3Notif(message *string) (string, string, error) { + s3notifBody := events.S3Event{} + err := json.Unmarshal([]byte(*message), &s3notifBody) + if err != nil { + return "", "", err + } + if len(s3notifBody.Records) == 0 { + return "", "", fmt.Errorf("no records found in S3 notification") + } + if !strings.HasPrefix(s3notifBody.Records[0].EventName, "ObjectCreated:") { + return "", "", fmt.Errorf("event %s is not supported", s3notifBody.Records[0].EventName) + } + return s3notifBody.Records[0].S3.Bucket.Name, s3notifBody.Records[0].S3.Object.Key, nil +} + +func (s *S3Source) extractBucketAndPrefix(message *string) (string, string, error) { + if s.Config.SQSFormat == SQSFormatEventBridge { + bucket, key, err := extractBucketAndPrefixFromEventBridge(message) + if err != nil { + return "", "", err + } + return bucket, key, nil + } else if s.Config.SQSFormat == SQSFormatS3Notification { + bucket, key, err := extractBucketAndPrefixFromS3Notif(message) + if err != nil { + return "", "", err + } + return bucket, key, nil + } else { + bucket, key, err := extractBucketAndPrefixFromEventBridge(message) + if err == nil { + s.Config.SQSFormat = SQSFormatEventBridge + return bucket, key, nil + } + bucket, key, err = extractBucketAndPrefixFromS3Notif(message) + if err == nil { + s.Config.SQSFormat = SQSFormatS3Notification + return bucket, key, nil + } + return "", "", fmt.Errorf("SQS message format not supported") + } +} + +func (s *S3Source) sqsPoll() error { + logger := s.logger.WithField("method", "sqsPoll") + for { + select { + case <-s.t.Dying(): + logger.Infof("Shutting down SQS poller") + s.cancel() + return nil + default: + logger.Trace("Polling SQS queue") + out, err := s.sqsClient.ReceiveMessageWithContext(s.ctx, &sqs.ReceiveMessageInput{ + QueueUrl: aws.String(s.Config.SQSName), + MaxNumberOfMessages: aws.Int64(10), + WaitTimeSeconds: 
				aws.Int64(20), // Probably no need to make this configurable?
			})
			if err != nil {
				logger.Errorf("Error while polling SQS: %s", err)
				continue
			}
			logger.Tracef("SQS output: %v", out)
			logger.Debugf("Received %d messages from SQS", len(out.Messages))
			for _, message := range out.Messages {
				sqsMessagesReceived.WithLabelValues(s.Config.SQSName).Inc()
				bucket, key, err := s.extractBucketAndPrefix(message.Body)
				if err != nil {
					logger.Errorf("Error while parsing SQS message: %s", err)
					// Always delete the message to avoid an infinite loop
					_, err = s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
						QueueUrl:      aws.String(s.Config.SQSName),
						ReceiptHandle: message.ReceiptHandle,
					})
					if err != nil {
						logger.Errorf("Error while deleting SQS message: %s", err)
					}
					continue
				}
				logger.Debugf("Received SQS message for object %s/%s", bucket, key)
				s.readerChan <- S3Object{Key: key, Bucket: bucket}
				_, err = s.sqsClient.DeleteMessage(&sqs.DeleteMessageInput{
					QueueUrl:      aws.String(s.Config.SQSName),
					ReceiptHandle: message.ReceiptHandle,
				})
				if err != nil {
					logger.Errorf("Error while deleting SQS message: %s", err)
				}
				logger.Debugf("Deleted SQS message for object %s/%s", bucket, key)
			}
		}
	}
}

func (s *S3Source) readFile(bucket string, key string) error {
	// TODO: Handle SSE-C
	var scanner *bufio.Scanner

	logger := s.logger.WithFields(logrus.Fields{
		"method": "readFile",
		"bucket": bucket,
		"key":    key,
	})

	output, err := s.s3Client.GetObjectWithContext(s.ctx, &s3.GetObjectInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		return fmt.Errorf("failed to get object %s/%s: %w", bucket, key, err)
	}
	defer output.Body.Close()
	if strings.HasSuffix(key, ".gz") {
		gzReader, err := gzip.NewReader(output.Body)
		if err != nil {
			return fmt.Errorf("failed to read gzip object %s/%s: %w", bucket, key, err)
		}
		defer gzReader.Close()
		scanner = bufio.NewScanner(gzReader)
	} else {
		scanner = bufio.NewScanner(output.Body)
	}
	for scanner.Scan() {
		text := scanner.Text()
		logger.Tracef("Read line %s", text)
		linesRead.WithLabelValues(bucket).Inc()
		l := types.Line{}
		l.Raw = text
		l.Labels = s.Config.Labels
		l.Time = time.Now().UTC()
		l.Process = true
		l.Module = s.GetName()
		l.Src = bucket
		var evt types.Event
		if !s.Config.UseTimeMachine {
			evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.LIVE}
		} else {
			evt = types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: types.TIMEMACHINE}
		}
		s.out <- evt
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("failed to read object %s/%s: %w", bucket, key, err)
	}
	objectsRead.WithLabelValues(bucket).Inc()
	return nil
}

func (s *S3Source) GetMetrics() []prometheus.Collector {
	return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived}
}

func (s *S3Source) GetAggregMetrics() []prometheus.Collector {
	return []prometheus.Collector{linesRead, objectsRead, sqsMessagesReceived}
}

func (s *S3Source) UnmarshalConfig(yamlConfig []byte) error {
	s.Config = S3Configuration{}
	err := yaml.UnmarshalStrict(yamlConfig, &s.Config)
	if err != nil {
		return fmt.Errorf("cannot parse S3Acquisition configuration: %w", err)
	}
	if s.Config.Mode == "" {
		s.Config.Mode = configuration.TAIL_MODE
	}
	if s.Config.PollingMethod == "" {
		s.Config.PollingMethod = PollMethodList
	}

	if s.Config.PollingInterval == 0 {
		s.Config.PollingInterval = 60
	}

	if s.Config.PollingMethod != 
	PollMethodList && s.Config.PollingMethod != PollMethodSQS {
		return fmt.Errorf("invalid polling method %s", s.Config.PollingMethod)
	}

	if s.Config.BucketName != "" && s.Config.SQSName != "" {
		return fmt.Errorf("bucket_name and sqs_name are mutually exclusive")
	}

	if s.Config.PollingMethod == PollMethodSQS && s.Config.SQSName == "" {
		return fmt.Errorf("sqs_name is required when using sqs polling method")
	}

	if s.Config.BucketName == "" && s.Config.PollingMethod == PollMethodList {
		return fmt.Errorf("bucket_name is required")
	}

	if s.Config.SQSFormat != "" && s.Config.SQSFormat != SQSFormatEventBridge && s.Config.SQSFormat != SQSFormatS3Notification {
		return fmt.Errorf("invalid sqs_format %s, must be empty, %s or %s", s.Config.SQSFormat, SQSFormatEventBridge, SQSFormatS3Notification)
	}

	return nil
}

func (s *S3Source) Configure(yamlConfig []byte, logger *log.Entry) error {
	err := s.UnmarshalConfig(yamlConfig)
	if err != nil {
		return err
	}

	if s.Config.SQSName != "" {
		s.logger = logger.WithFields(log.Fields{
			"queue": s.Config.SQSName,
		})
	} else {
		s.logger = logger.WithFields(log.Fields{
			"bucket": s.Config.BucketName,
			"prefix": s.Config.Prefix,
		})
	}

	if !s.Config.UseTimeMachine {
		s.logger.Warning("use_time_machine is not set to true in the datasource configuration. This will likely lead to false positives as S3 logs are not processed in real time.")
	}

	if s.Config.PollingMethod == PollMethodList {
		s.logger.Warning("Polling method is set to list. This is not recommended as it will not scale well. Consider using SQS instead.")
	}

	err = s.newS3Client()
	if err != nil {
		return err
	}

	if s.Config.PollingMethod == PollMethodSQS {
		err = s.newSQSClient()
		if err != nil {
			return err
		}
	}

	return nil
}

func (s *S3Source) ConfigureByDSN(dsn string, labels map[string]string, logger *log.Entry) error {
	if !strings.HasPrefix(dsn, "s3://") {
		return fmt.Errorf("invalid DSN %s for S3 source, must start with s3://", dsn)
	}

	s.logger = logger.WithFields(log.Fields{
		"bucket": s.Config.BucketName,
		"prefix": s.Config.Prefix,
	})

	dsn = strings.TrimPrefix(dsn, "s3://")
	args := strings.Split(dsn, "?")
	if len(args[0]) == 0 {
		return fmt.Errorf("empty s3:// DSN")
	}

	if len(args) == 2 && len(args[1]) != 0 {
		params, err := url.ParseQuery(args[1])
		if err != nil {
			return errors.Wrap(err, "could not parse s3 args")
		}
		for key, value := range params {
			if key != "log_level" {
				return fmt.Errorf("unsupported key %s in s3 DSN", key)
			}
			if len(value) != 1 {
				return errors.New("expected exactly one value for 'log_level'")
			}
			lvl, err := log.ParseLevel(value[0])
			if err != nil {
				return errors.Wrapf(err, "unknown level %s", value[0])
			}
			s.logger.Logger.SetLevel(lvl)
		}
	}

	s.Config = S3Configuration{}
	s.Config.Labels = labels
	s.Config.Mode = configuration.CAT_MODE

	pathParts := strings.Split(args[0], "/")
	s.logger.Debugf("pathParts: %v", pathParts)

	// FIXME: handle s3://bucket/
	if len(pathParts) == 1 {
		s.Config.BucketName = pathParts[0]
		s.Config.Prefix = ""
	} else if len(pathParts) > 1 {
		s.Config.BucketName = pathParts[0]
		if args[0][len(args[0])-1] == '/' {
			s.Config.Prefix = strings.Join(pathParts[1:], "/")
		} else {
			s.Config.Key = strings.Join(pathParts[1:], "/")
		}
	} else {
		return fmt.Errorf("invalid DSN %s for S3 source", dsn)
	}

	err := s.newS3Client()
	if err != nil {
		return err
	}

	return nil
}

func (s *S3Source) 
GetMode() string {
	return s.Config.Mode
}

func (s *S3Source) GetName() string {
	return "s3"
}

func (s *S3Source) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error {
	s.logger.Infof("starting acquisition of %s/%s/%s", s.Config.BucketName, s.Config.Prefix, s.Config.Key)
	s.out = out
	s.ctx, s.cancel = context.WithCancel(context.Background())
	if s.Config.Key != "" {
		err := s.readFile(s.Config.BucketName, s.Config.Key)
		if err != nil {
			return err
		}
	} else {
		// No key, get everything in the bucket based on the prefix
		objects, err := s.getBucketContent()
		if err != nil {
			return err
		}
		for _, object := range objects {
			err := s.readFile(s.Config.BucketName, *object.Key)
			if err != nil {
				return err
			}
		}
	}
	return nil
}

func (s *S3Source) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error {
	s.t = t
	s.out = out
	s.readerChan = make(chan S3Object, 100) // FIXME: does this need to be buffered?
	s.ctx, s.cancel = context.WithCancel(context.Background())
	s.logger.Infof("starting acquisition of %s/%s", s.Config.BucketName, s.Config.Prefix)
	t.Go(func() error {
		s.readManager()
		return nil
	})
	if s.Config.PollingMethod == PollMethodSQS {
		t.Go(func() error {
			err := s.sqsPoll()
			if err != nil {
				return err
			}
			return nil
		})
	} else {
		t.Go(func() error {
			err := s.listPoll()
			if err != nil {
				return err
			}
			return nil
		})
	}
	return nil
}

func (s *S3Source) CanRun() error {
	return nil
}

func (s *S3Source) Dump() interface{} {
	return s
}
diff --git a/pkg/acquisition/modules/s3/s3_test.go b/pkg/acquisition/modules/s3/s3_test.go
new file mode 100644
index 000000000..5f7deeda1
--- /dev/null
+++ b/pkg/acquisition/modules/s3/s3_test.go
@@ -0,0 +1,429 @@
package s3acquisition

import (
	"context"
	"fmt"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
	"github.com/aws/aws-sdk-go/service/sqs"
	"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
	"github.com/crowdsecurity/crowdsec/pkg/types"
	log "github.com/sirupsen/logrus"
	"github.com/stretchr/testify/assert"
	"gopkg.in/tomb.v2"
)

func TestBadConfiguration(t *testing.T) {
	tests := []struct {
		name        string
		config      string
		expectedErr string
	}{
		{
			name: "no bucket",
			config: `
source: s3
`,
			expectedErr: "bucket_name is required",
		},
		{
			name: "invalid polling method",
			config: `
source: s3
bucket_name: foobar
polling_method: foobar
`,
			expectedErr: "invalid polling method foobar",
		},
		{
			name: "no sqs name",
			config: `
source: s3
bucket_name: foobar
polling_method: sqs
`,
			expectedErr: "sqs_name is required when using sqs polling method",
		},
		{
			name: "both bucket and sqs",
			config: `
source: s3
bucket_name: foobar
polling_method: sqs
sqs_name: foobar
`,
			expectedErr: "bucket_name and sqs_name are mutually exclusive",
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			f := S3Source{}
			err := f.Configure([]byte(test.config), nil)
			if err == nil {
				t.Fatalf("expected error, got none")
			}
			if err.Error() != test.expectedErr {
				t.Fatalf("expected error %s, got %s", test.expectedErr, err.Error())
			}
		})
	}
}

func TestGoodConfiguration(t *testing.T) {
	tests := []struct {
		name   string
		config string
	}{
		{
			name: "basic",
			config: `
source: s3
bucket_name: foobar
`,
		},
		{
name: "polling method", + config: ` +source: s3 +polling_method: sqs +sqs_name: foobar +`, + }, + { + name: "list method", + config: ` +source: s3 +bucket_name: foobar +polling_method: list +`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := S3Source{} + logger := log.NewEntry(log.New()) + err := f.Configure([]byte(test.config), logger) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + }) + } +} + +type mockS3Client struct { + s3iface.S3API +} + +// We add one hour to trick the listing goroutine into thinking the files are new +var mockListOutput map[string][]*s3.Object = map[string][]*s3.Object{ + "bucket_no_prefix": { + { + Key: aws.String("foo.log"), + LastModified: aws.Time(time.Now().Add(time.Hour)), + }, + }, + "bucket_with_prefix": { + { + Key: aws.String("prefix/foo.log"), + LastModified: aws.Time(time.Now().Add(time.Hour)), + }, + { + Key: aws.String("prefix/bar.log"), + LastModified: aws.Time(time.Now().Add(time.Hour)), + }, + }, +} + +func (m mockS3Client) ListObjectsV2WithContext(ctx context.Context, input *s3.ListObjectsV2Input, options ...request.Option) (*s3.ListObjectsV2Output, error) { + log.Infof("returning mock list output for %s, %v", *input.Bucket, mockListOutput[*input.Bucket]) + return &s3.ListObjectsV2Output{ + Contents: mockListOutput[*input.Bucket], + }, nil +} + +func (m mockS3Client) GetObjectWithContext(ctx context.Context, input *s3.GetObjectInput, options ...request.Option) (*s3.GetObjectOutput, error) { + r := strings.NewReader("foo\nbar") + return &s3.GetObjectOutput{ + Body: aws.ReadSeekCloser(r), + }, nil +} + +type mockSQSClient struct { + sqsiface.SQSAPI + counter *int32 +} + +func (msqs mockSQSClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) { + if atomic.LoadInt32(msqs.counter) == 1 { + return &sqs.ReceiveMessageOutput{}, nil + } + atomic.AddInt32(msqs.counter, 1) + return &sqs.ReceiveMessageOutput{ + Messages: []*sqs.Message{ + { + Body: aws.String(` +{"version":"0","id":"af1ce7ea-bdb4-5bb7-3af2-c6cb32f9aac9","detail-type":"Object Created","source":"aws.s3","account":"1234","time":"2023-03-17T07:45:04Z","region":"eu-west-1","resources":["arn:aws:s3:::my_bucket"],"detail":{"version":"0","bucket":{"name":"my_bucket"},"object":{"key":"foo.log","size":663,"etag":"f2d5268a0776d6cdd6e14fcfba96d1cd","sequencer":"0064141A8022966874"},"request-id":"MBWX2P6FWA3S1YH5","requester":"156460612806","source-ip-address":"42.42.42.42","reason":"PutObject"}}`), + }, + }, + }, nil +} + +func (msqs mockSQSClient) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) { + return &sqs.DeleteMessageOutput{}, nil +} + +type mockSQSClientNotif struct { + sqsiface.SQSAPI + counter *int32 +} + +func (msqs mockSQSClientNotif) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, options ...request.Option) (*sqs.ReceiveMessageOutput, error) { + if atomic.LoadInt32(msqs.counter) == 1 { + return &sqs.ReceiveMessageOutput{}, nil + } + atomic.AddInt32(msqs.counter, 1) + return &sqs.ReceiveMessageOutput{ + Messages: []*sqs.Message{ + { + Body: aws.String(` + 
{"Records":[{"eventVersion":"2.1","eventSource":"aws:s3","awsRegion":"eu-west-1","eventTime":"2023-03-20T19:30:02.536Z","eventName":"ObjectCreated:Put","userIdentity":{"principalId":"AWS:XXXXX"},"requestParameters":{"sourceIPAddress":"42.42.42.42"},"responseElements":{"x-amz-request-id":"FM0TAV2WE5AXXW42","x-amz-id-2":"LCfQt1aSBtD1G5wdXjB5ANdPxLEXJxA89Ev+/rRAsCGFNJGI/1+HMlKI59S92lqvzfViWh7B74leGKWB8/nNbsbKbK7WXKz2"},"s3":{"s3SchemaVersion":"1.0","configurationId":"test-acquis","bucket":{"name":"my_bucket","ownerIdentity":{"principalId":"A1F2PSER1FB8MY"},"arn":"arn:aws:s3:::my_bucket"},"object":{"key":"foo.log","size":3097,"eTag":"ab6889744611c77991cbc6ca12d1ddc7","sequencer":"006418B43A76BC0257"}}}]}`), + }, + }, + }, nil +} + +func (msqs mockSQSClientNotif) DeleteMessage(input *sqs.DeleteMessageInput) (*sqs.DeleteMessageOutput, error) { + return &sqs.DeleteMessageOutput{}, nil +} + +func TestDSNAcquis(t *testing.T) { + tests := []struct { + name string + dsn string + expectedBucketName string + expectedPrefix string + expectedCount int + }{ + { + name: "basic", + dsn: "s3://bucket_no_prefix/foo.log", + expectedBucketName: "bucket_no_prefix", + expectedPrefix: "", + expectedCount: 2, + }, + { + name: "with prefix", + dsn: "s3://bucket_with_prefix/prefix/", + expectedBucketName: "bucket_with_prefix", + expectedPrefix: "prefix/", + expectedCount: 4, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + linesRead := 0 + f := S3Source{} + logger := log.NewEntry(log.New()) + err := f.ConfigureByDSN(test.dsn, map[string]string{"foo": "bar"}, logger) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + assert.Equal(t, test.expectedBucketName, f.Config.BucketName) + assert.Equal(t, test.expectedPrefix, f.Config.Prefix) + out := make(chan types.Event) + + done := make(chan bool) + + go func() { + for { + select { + case s := <-out: + fmt.Printf("got line %s\n", s.Line.Raw) + linesRead++ + case <-done: + return + } + } + }() + + f.s3Client = mockS3Client{} + err = f.OneShotAcquisition(out, nil) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + time.Sleep(2 * time.Second) + done <- true + assert.Equal(t, test.expectedCount, linesRead) + + }) + } + +} + +func TestListPolling(t *testing.T) { + tests := []struct { + name string + config string + expectedCount int + }{ + { + name: "basic", + config: ` +source: s3 +bucket_name: bucket_no_prefix +polling_method: list +polling_interval: 1 +`, + expectedCount: 2, + }, + { + name: "with prefix", + config: ` +source: s3 +bucket_name: bucket_with_prefix +polling_method: list +polling_interval: 1 +prefix: foo/ +`, + expectedCount: 4, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + linesRead := 0 + f := S3Source{} + logger := log.NewEntry(log.New()) + logger.Logger.SetLevel(log.TraceLevel) + err := f.Configure([]byte(test.config), logger) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + if f.Config.PollingMethod != PollMethodList { + t.Fatalf("expected list polling, got %s", f.Config.PollingMethod) + } + + f.s3Client = mockS3Client{} + + out := make(chan types.Event) + tb := tomb.Tomb{} + + go func() { + for { + select { + case s := <-out: + fmt.Printf("got line %s\n", s.Line.Raw) + linesRead++ + case <-tb.Dying(): + return + } + } + }() + + err = f.StreamingAcquisition(out, &tb) + + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + time.Sleep(2 * time.Second) + tb.Kill(nil) + err = tb.Wait() + if err != nil { + 
t.Fatalf("unexpected error: %s", err.Error()) + } + assert.Equal(t, test.expectedCount, linesRead) + }) + } +} + +func TestSQSPoll(t *testing.T) { + tests := []struct { + name string + config string + notifType string + expectedCount int + }{ + { + name: "eventbridge", + config: ` +source: s3 +polling_method: sqs +sqs_name: test +`, + expectedCount: 2, + notifType: "eventbridge", + }, + { + name: "notification", + config: ` +source: s3 +polling_method: sqs +sqs_name: test +`, + expectedCount: 2, + notifType: "notification", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + linesRead := 0 + f := S3Source{} + logger := log.NewEntry(log.New()) + err := f.Configure([]byte(test.config), logger) + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + if f.Config.PollingMethod != PollMethodSQS { + t.Fatalf("expected sqs polling, got %s", f.Config.PollingMethod) + } + + counter := int32(0) + f.s3Client = mockS3Client{} + if test.notifType == "eventbridge" { + f.sqsClient = mockSQSClient{counter: &counter} + } else { + f.sqsClient = mockSQSClientNotif{counter: &counter} + } + + out := make(chan types.Event) + tb := tomb.Tomb{} + + go func() { + for { + select { + case s := <-out: + fmt.Printf("got line %s\n", s.Line.Raw) + linesRead++ + case <-tb.Dying(): + return + } + } + }() + + err = f.StreamingAcquisition(out, &tb) + + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + + time.Sleep(2 * time.Second) + tb.Kill(nil) + err = tb.Wait() + if err != nil { + t.Fatalf("unexpected error: %s", err.Error()) + } + assert.Equal(t, test.expectedCount, linesRead) + }) + } +}
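
Usage sketch, for reviewers trying the datasource: two example acquisition entries, one per polling method. Bucket, queue, region and label values are placeholders; the field names come from the S3Configuration struct above, and use_time_machine is the common datasource flag that Configure() warns about when unset.

# list polling: periodically calls ListObjectsV2 and reads any object newer than the last seen
source: s3
bucket_name: my-bucket            # placeholder
prefix: AWSLogs/                  # optional key prefix
polling_method: list              # the default
polling_interval: 60              # seconds; the default
aws_region: eu-west-1             # optional, falls back to the SDK shared config
use_time_machine: true            # avoids false positives, since S3 logs arrive with a delay
labels:
  type: cloudtrail                # placeholder log type

# sqs polling: object-created events are pushed to a queue via EventBridge or S3 Notifications
source: s3
polling_method: sqs
sqs_name: https://sqs.eu-west-1.amazonaws.com/123456789012/my-queue  # placeholder; passed as-is as QueueUrl, so a full queue URL is the safe value
sqs_format: eventbridge           # optional; auto-detected between eventbridge and s3notification
use_time_machine: true
labels:
  type: cloudtrail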
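The DSN form accepted by ConfigureByDSN covers three shapes: s3://bucket (whole bucket), s3://bucket/prefix/ (trailing slash means prefix listing) and s3://bucket/path/to/object.log (single object), with an optional log_level query parameter. A hypothetical one-shot replay, assuming the usual -dsn/-type/-no-api flags of the crowdsec binary and a placeholder bucket:

crowdsec -no-api -type cloudtrail -dsn "s3://my-bucket/AWSLogs/2023/03/21/foo.log.gz?log_level=debug"

Gzipped objects are detected by the .gz suffix and decompressed transparently, as in readFile above.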
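The two sqs_format values correspond to the payloads exercised in the tests: "eventbridge" expects the EventBridge "Object Created" envelope parsed by extractBucketAndPrefixFromEventBridge, while "s3notification" expects the classic Records array and only accepts ObjectCreated:* events. For the EventBridge path, an illustrative rule pattern along these lines (bucket name is a placeholder; wiring the rule's target to the SQS queue is not shown) would produce messages the parser understands:

{
  "source": ["aws.s3"],
  "detail-type": ["Object Created"],
  "detail": { "bucket": { "name": ["my-bucket"] } }
}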
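Permissions follow directly from the SDK calls made here: ListObjectsV2 needs s3:ListBucket, GetObject needs s3:GetObject, and the SQS poller needs sqs:ReceiveMessage plus sqs:DeleteMessage. A minimal IAM policy sketch with placeholder ARNs:

{
  "Version": "2012-10-17",
  "Statement": [
    { "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::my-bucket" },
    { "Effect": "Allow", "Action": "s3:GetObject", "Resource": "arn:aws:s3:::my-bucket/*" },
    { "Effect": "Allow", "Action": ["sqs:ReceiveMessage", "sqs:DeleteMessage"], "Resource": "arn:aws:sqs:eu-west-1:123456789012:my-queue" }
  ]
}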