diff --git a/.github/workflows/ci_go-test.yml b/.github/workflows/ci_go-test.yml index 76bb9f495..0e9dc7922 100644 --- a/.github/workflows/ci_go-test.yml +++ b/.github/workflows/ci_go-test.yml @@ -3,13 +3,14 @@ name: tests #those env variables are for localstack, so we can emulate aws services env: AWS_HOST: localstack - SERVICES: cloudwatch,logs + SERVICES: cloudwatch,logs,kinesis #those are to mimic aws config AWS_ACCESS_KEY_ID: AKIAIOSFODNN7EXAMPLE AWS_SECRET_ACCESS_KEY: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY AWS_REGION: us-east-1 #and to override our endpoint in aws sdk - AWS_ENDPOINT_FORCE: http://localhost:4566 + AWS_ENDPOINT_FORCE: http://localhost:4566 + KINESIS_INITIALIZE_STREAMS: "stream-1-shard:1,stream-2-shards:2" on: push: @@ -32,7 +33,7 @@ jobs: runs-on: ubuntu-latest services: localstack: - image: localstack/localstack:0.12.11 + image: localstack/localstack:0.13.3 ports: - 4566:4566 # Localstack exposes all services on same port env: @@ -43,6 +44,7 @@ jobs: KINESIS_ERROR_PROBABILITY: "" DOCKER_HOST: unix:///var/run/docker.sock HOST_TMP_FOLDER: "/tmp" + KINESIS_INITIALIZE_STREAMS: ${{ env.KINESIS_INITIALIZE_STREAMS }} HOSTNAME_EXTERNAL: ${{ env.AWS_HOST }} # Required so that resource urls are provided properly # e.g sqs url will get localhost if we don't set this env to map our service options: >- diff --git a/go.mod b/go.mod index 6d760a107..f0334a5ea 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/alexliesenfeld/health v0.5.1 github.com/antonmedv/expr v1.8.9 github.com/appleboy/gin-jwt/v2 v2.6.4 - github.com/aws/aws-sdk-go v1.38.34 + github.com/aws/aws-sdk-go v1.42.25 github.com/buger/jsonparser v1.1.1 github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf github.com/crowdsecurity/grokky v0.0.0-20210908095311-0b3373925934 @@ -129,10 +129,10 @@ require ( github.com/vjeantet/grok v1.0.1 // indirect github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect go.mongodb.org/mongo-driver v1.4.4 // indirect - golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 // indirect + golang.org/x/net v0.0.0-20211209124913-491a49abca63 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/term v0.0.0-20201210144234-2321bbc49cbf // indirect - golang.org/x/text v0.3.5 // indirect + golang.org/x/text v0.3.6 // indirect google.golang.org/appengine v1.6.6 // indirect google.golang.org/genproto v0.0.0-20210114201628-6edceaf6022f // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect diff --git a/go.sum b/go.sum index 56b4ed277..b94c5ff3b 100644 --- a/go.sum +++ b/go.sum @@ -80,6 +80,8 @@ github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= github.com/aws/aws-sdk-go v1.38.34 h1:JSAyS6hSDLbRmCAz9VAkwDf5oh/olt9mBTrVBWGJcU8= github.com/aws/aws-sdk-go v1.38.34/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= +github.com/aws/aws-sdk-go v1.42.25 h1:BbdvHAi+t9LRiaYUyd53noq9jcaAcfzOhSVbKfr6Avs= +github.com/aws/aws-sdk-go v1.42.25/go.mod h1:gyRszuZ/icHmHAVE4gc/r+cfCmhA1AD+vqfWbgI+eHs= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= @@ -816,6 +818,8 @@ golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20211209124913-491a49abca63 h1:iocB37TsdFuN6IBRZ+ry36wrkoV51/tl5vOWqkcPGvY= +golang.org/x/net v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -876,6 +880,7 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210309074719-68d13333faf2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210921065528-437939a70204 h1:JJhkWtBuTQKyz2bd5WG9H8iUsJRU3En/KRfN8B2RnDs= @@ -891,6 +896,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5 h1:i6eZZ+zk0SOf0xgBpEpPD18qWcJda6q1sxt3S0kzyUQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/pkg/acquisition/acquisition.go b/pkg/acquisition/acquisition.go index 8bd04b86f..648417c5c 100644 --- a/pkg/acquisition/acquisition.go +++ b/pkg/acquisition/acquisition.go @@ -11,6 +11,7 @@ import ( dockeracquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/docker" fileacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/file" journalctlacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/journalctl" + kinesisacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/kinesis" syslogacquisition "github.com/crowdsecurity/crowdsec/pkg/acquisition/modules/syslog" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -60,6 +61,10 @@ var AcquisitionSources = []struct { name: "docker", iface: func() DataSource { return &dockeracquisition.DockerSource{} }, }, + { + name: "kinesis", + iface: func() DataSource { return &kinesisacquisition.KinesisSource{} }, + }, } func GetDataSourceIface(dataSourceType string) DataSource { diff --git a/pkg/acquisition/modules/kinesis/kinesis.go b/pkg/acquisition/modules/kinesis/kinesis.go new file mode 100644 index 000000000..1e68f4f1a --- /dev/null +++ b/pkg/acquisition/modules/kinesis/kinesis.go @@ -0,0 +1,510 @@ +package kinesisacquisition + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "fmt" + "io/ioutil" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/crowdsecurity/crowdsec/pkg/acquisition/configuration" + "github.com/crowdsecurity/crowdsec/pkg/leakybucket" + "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/pkg/errors" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "gopkg.in/tomb.v2" + "gopkg.in/yaml.v2" +) + +type KinesisConfiguration struct { + configuration.DataSourceCommonCfg `yaml:",inline"` + StreamName string `yaml:"stream_name"` + StreamARN string `yaml:"stream_arn"` + UseEnhancedFanOut bool `yaml:"use_enhanced_fanout"` //Use RegisterStreamConsumer and SubscribeToShard instead of GetRecords + AwsProfile *string `yaml:"aws_profile"` + AwsRegion string `yaml:"aws_region"` + AwsEndpoint string `yaml:"aws_endpoint"` + ConsumerName string `yaml:"consumer_name"` + FromSubscription bool `yaml:"from_subscription"` + MaxRetries int `yaml:"max_retries"` +} + +type KinesisSource struct { + Config KinesisConfiguration + logger *log.Entry + kClient *kinesis.Kinesis + shardReaderTomb *tomb.Tomb +} + +type CloudWatchSubscriptionRecord struct { + MessageType string `json:"messageType"` + Owner string `json:"owner"` + LogGroup string `json:"logGroup"` + LogStream string `json:"logStream"` + SubscriptionFilters []string `json:"subscriptionFilters"` + LogEvents []CloudwatchSubscriptionLogEvent `json:"logEvents"` +} + +type CloudwatchSubscriptionLogEvent struct { + ID string `json:"id"` + Message string `json:"message"` + Timestamp int64 `json:"timestamp"` +} + +var linesRead = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_kinesis_stream_hits_total", + Help: "Number of event read per stream.", + }, + []string{"stream"}, +) + +var linesReadShards = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Name: "cs_kinesis_shards_hits_total", + Help: "Number of event read per shards.", + }, + []string{"stream", "shard"}, +) + +func (k *KinesisSource) newClient() error { + var sess *session.Session + + if k.Config.AwsProfile != nil { + sess = session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + Profile: *k.Config.AwsProfile, + })) + } else { + sess = session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + } + + if sess == nil { + return fmt.Errorf("failed to create aws session") + } + config := aws.NewConfig() + if k.Config.AwsRegion != "" { + config = config.WithRegion(k.Config.AwsRegion) + } + if k.Config.AwsEndpoint != "" { + config = config.WithEndpoint(k.Config.AwsEndpoint) + } + k.kClient = kinesis.New(sess, config) + if k.kClient == nil { + return fmt.Errorf("failed to create kinesis client") + } + return nil +} + +func (k *KinesisSource) GetMetrics() []prometheus.Collector { + return []prometheus.Collector{linesRead, linesReadShards} + +} +func (k *KinesisSource) GetAggregMetrics() []prometheus.Collector { + return []prometheus.Collector{linesRead, linesReadShards} +} + +func (k *KinesisSource) Configure(yamlConfig []byte, logger *log.Entry) error { + config := KinesisConfiguration{} + k.logger = logger + err := yaml.UnmarshalStrict(yamlConfig, &config) + if err != nil { + return errors.Wrap(err, "Cannot parse kinesis datasource configuration") + } + if config.Mode == "" { + config.Mode = configuration.TAIL_MODE + } + k.Config = config + if k.Config.StreamName == "" && !k.Config.UseEnhancedFanOut { + return fmt.Errorf("stream_name is mandatory when use_enhanced_fanout is false") + } + if k.Config.StreamARN == "" && k.Config.UseEnhancedFanOut { + return fmt.Errorf("stream_arn is mandatory when use_enhanced_fanout is true") + } + if k.Config.ConsumerName == "" && k.Config.UseEnhancedFanOut { + return fmt.Errorf("consumer_name is mandatory when use_enhanced_fanout is true") + } + if k.Config.StreamARN != "" && k.Config.StreamName != "" { + return fmt.Errorf("stream_arn and stream_name are mutually exclusive") + } + if k.Config.MaxRetries <= 0 { + k.Config.MaxRetries = 10 + } + err = k.newClient() + if err != nil { + return errors.Wrap(err, "Cannot create kinesis client") + } + k.shardReaderTomb = &tomb.Tomb{} + return nil +} + +func (k *KinesisSource) ConfigureByDSN(string, map[string]string, *log.Entry) error { + return fmt.Errorf("kinesis datasource does not support command-line acquisition") +} + +func (k *KinesisSource) GetMode() string { + return k.Config.Mode +} + +func (k *KinesisSource) GetName() string { + return "kinesis" +} + +func (k *KinesisSource) OneShotAcquisition(out chan types.Event, t *tomb.Tomb) error { + return fmt.Errorf("kinesis datasource does not support one-shot acquisition") +} + +func (k *KinesisSource) decodeFromSubscription(record []byte) ([]CloudwatchSubscriptionLogEvent, error) { + b := bytes.NewBuffer(record) + r, err := gzip.NewReader(b) + + if err != nil { + k.logger.Error(err) + return nil, err + } + decompressed, err := ioutil.ReadAll(r) + if err != nil { + k.logger.Error(err) + return nil, err + } + var subscriptionRecord CloudWatchSubscriptionRecord + err = json.Unmarshal(decompressed, &subscriptionRecord) + if err != nil { + k.logger.Error(err) + return nil, err + } + return subscriptionRecord.LogEvents, nil +} + +func (k *KinesisSource) WaitForConsumerDeregistration(consumerName string, streamARN string) error { + maxTries := k.Config.MaxRetries + for i := 0; i < maxTries; i++ { + _, err := k.kClient.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{ + ConsumerName: aws.String(consumerName), + StreamARN: aws.String(streamARN), + }) + if err != nil { + switch err.(type) { + case *kinesis.ResourceNotFoundException: + return nil + default: + k.logger.Errorf("Error while waiting for consumer deregistration: %s", err) + return errors.Wrap(err, "Cannot describe stream consumer") + } + } + time.Sleep(time.Millisecond * 200 * time.Duration(i+1)) + } + return fmt.Errorf("consumer %s is not deregistered after %d tries", consumerName, maxTries) +} + +func (k *KinesisSource) DeregisterConsumer() error { + k.logger.Debugf("Deregistering consumer %s if it exists", k.Config.ConsumerName) + _, err := k.kClient.DeregisterStreamConsumer(&kinesis.DeregisterStreamConsumerInput{ + ConsumerName: aws.String(k.Config.ConsumerName), + StreamARN: aws.String(k.Config.StreamARN), + }) + if err != nil { + switch err.(type) { + case *kinesis.ResourceNotFoundException: + default: + return errors.Wrap(err, "Cannot deregister stream consumer") + } + } + err = k.WaitForConsumerDeregistration(k.Config.ConsumerName, k.Config.StreamARN) + if err != nil { + return errors.Wrap(err, "Cannot wait for consumer deregistration") + } + return nil +} + +func (k *KinesisSource) WaitForConsumerRegistration(consumerARN string) error { + maxTries := k.Config.MaxRetries + for i := 0; i < maxTries; i++ { + describeOutput, err := k.kClient.DescribeStreamConsumer(&kinesis.DescribeStreamConsumerInput{ + ConsumerARN: aws.String(consumerARN), + }) + if err != nil { + return errors.Wrap(err, "Cannot describe stream consumer") + } + if *describeOutput.ConsumerDescription.ConsumerStatus == "ACTIVE" { + k.logger.Debugf("Consumer %s is active", consumerARN) + return nil + } + time.Sleep(time.Millisecond * 200 * time.Duration(i+1)) + k.logger.Debugf("Waiting for consumer registration %d", i) + } + return fmt.Errorf("consumer %s is not active after %d tries", consumerARN, maxTries) +} + +func (k *KinesisSource) RegisterConsumer() (*kinesis.RegisterStreamConsumerOutput, error) { + k.logger.Debugf("Registering consumer %s", k.Config.ConsumerName) + streamConsumer, err := k.kClient.RegisterStreamConsumer(&kinesis.RegisterStreamConsumerInput{ + ConsumerName: aws.String(k.Config.ConsumerName), + StreamARN: aws.String(k.Config.StreamARN), + }) + if err != nil { + return nil, errors.Wrap(err, "Cannot register stream consumer") + } + err = k.WaitForConsumerRegistration(*streamConsumer.Consumer.ConsumerARN) + if err != nil { + return nil, errors.Wrap(err, "Timeout while waiting for consumer to be active") + } + return streamConsumer, nil +} + +func (k *KinesisSource) ParseAndPushRecords(records []*kinesis.Record, out chan types.Event, logger *log.Entry, shardId string) { + for _, record := range records { + if k.Config.StreamARN != "" { + linesReadShards.With(prometheus.Labels{"stream": k.Config.StreamARN, "shard": shardId}).Inc() + linesRead.With(prometheus.Labels{"stream": k.Config.StreamARN}).Inc() + } else { + linesReadShards.With(prometheus.Labels{"stream": k.Config.StreamName, "shard": shardId}).Inc() + linesRead.With(prometheus.Labels{"stream": k.Config.StreamName}).Inc() + } + var data []CloudwatchSubscriptionLogEvent + var err error + if k.Config.FromSubscription { + //The AWS docs says that the data is base64 encoded + //but apparently GetRecords decodes it for us ? + data, err = k.decodeFromSubscription(record.Data) + if err != nil { + logger.Errorf("Cannot decode data: %s", err) + continue + } + } else { + data = []CloudwatchSubscriptionLogEvent{{Message: string(record.Data)}} + } + for _, event := range data { + logger.Tracef("got record %s", event.Message) + l := types.Line{} + l.Raw = event.Message + l.Labels = k.Config.Labels + l.Time = time.Now() + l.Process = true + l.Module = k.GetName() + if k.Config.StreamARN != "" { + l.Src = k.Config.StreamARN + } else { + l.Src = k.Config.StreamName + } + evt := types.Event{Line: l, Process: true, Type: types.LOG, ExpectMode: leakybucket.LIVE} + out <- evt + } + } +} + +func (k *KinesisSource) ReadFromSubscription(reader kinesis.SubscribeToShardEventStreamReader, out chan types.Event, shardId string, streamName string) error { + logger := k.logger.WithFields(log.Fields{"shard_id": shardId}) + //ghetto sync, kinesis allows to subscribe to a closed shard, which will make the goroutine exit immediately + //and we won't be able to start a new one if this is the first one started by the tomb + //TODO: look into parent shards to see if a shard is closed before starting to read it ? + time.Sleep(time.Second) + for { + select { + case <-k.shardReaderTomb.Dying(): + logger.Infof("Subscribed shard reader is dying") + err := reader.Close() + if err != nil { + return errors.Wrap(err, "Cannot close kinesis subscribed shard reader") + } + return nil + case event, ok := <-reader.Events(): + if !ok { + logger.Infof("Event chan has been closed") + return nil + } + switch event := event.(type) { + case *kinesis.SubscribeToShardEvent: + k.ParseAndPushRecords(event.Records, out, logger, shardId) + case *kinesis.SubscribeToShardEventStreamUnknownEvent: + logger.Infof("got an unknown event, what to do ?") + } + } + } +} + +func (k *KinesisSource) SubscribeToShards(arn arn.ARN, streamConsumer *kinesis.RegisterStreamConsumerOutput, out chan types.Event) error { + shards, err := k.kClient.ListShards(&kinesis.ListShardsInput{ + StreamName: aws.String(arn.Resource[7:]), + }) + if err != nil { + return errors.Wrap(err, "Cannot list shards for enhanced_read") + } + + for _, shard := range shards.Shards { + shardId := *shard.ShardId + r, err := k.kClient.SubscribeToShard(&kinesis.SubscribeToShardInput{ + ShardId: aws.String(shardId), + StartingPosition: &kinesis.StartingPosition{Type: aws.String(kinesis.ShardIteratorTypeLatest)}, + ConsumerARN: streamConsumer.Consumer.ConsumerARN, + }) + if err != nil { + return errors.Wrap(err, "Cannot subscribe to shard") + } + k.shardReaderTomb.Go(func() error { + return k.ReadFromSubscription(r.GetEventStream().Reader, out, shardId, arn.Resource[7:]) + }) + } + return nil +} + +func (k *KinesisSource) EnhancedRead(out chan types.Event, t *tomb.Tomb) error { + parsedARN, err := arn.Parse(k.Config.StreamARN) + if err != nil { + return errors.Wrap(err, "Cannot parse stream ARN") + } + if !strings.HasPrefix(parsedARN.Resource, "stream/") { + return fmt.Errorf("resource part of stream ARN %s does not start with stream/", k.Config.StreamARN) + } + + k.logger = k.logger.WithFields(log.Fields{"stream": parsedARN.Resource[7:]}) + k.logger.Info("starting kinesis acquisition with enhanced fan-out") + err = k.DeregisterConsumer() + if err != nil { + return errors.Wrap(err, "Cannot deregister consumer") + } + + streamConsumer, err := k.RegisterConsumer() + if err != nil { + return errors.Wrap(err, "Cannot register consumer") + } + + for { + k.shardReaderTomb = &tomb.Tomb{} + + err = k.SubscribeToShards(parsedARN, streamConsumer, out) + if err != nil { + return errors.Wrap(err, "Cannot subscribe to shards") + } + select { + case <-t.Dying(): + k.logger.Infof("Kinesis source is dying") + k.shardReaderTomb.Kill(nil) + _ = k.shardReaderTomb.Wait() //we don't care about the error as we kill the tomb ourselves + err = k.DeregisterConsumer() + if err != nil { + return errors.Wrap(err, "Cannot deregister consumer") + } + return nil + case <-k.shardReaderTomb.Dying(): + k.logger.Debugf("Kinesis subscribed shard reader is dying") + if k.shardReaderTomb.Err() != nil { + return k.shardReaderTomb.Err() + } + //All goroutines have exited without error, so a resharding event, start again + k.logger.Debugf("All reader goroutines have exited, resharding event or periodic resubscribe") + continue + } + } +} + +func (k *KinesisSource) ReadFromShard(out chan types.Event, shardId string) error { + logger := k.logger.WithFields(log.Fields{"shard": shardId}) + logger.Debugf("Starting to read shard") + sharIt, err := k.kClient.GetShardIterator(&kinesis.GetShardIteratorInput{ShardId: aws.String(shardId), + StreamName: &k.Config.StreamName, + ShardIteratorType: aws.String(kinesis.ShardIteratorTypeLatest)}) + if err != nil { + logger.Errorf("Cannot get shard iterator: %s", err) + return errors.Wrap(err, "Cannot get shard iterator") + } + it := sharIt.ShardIterator + //AWS recommends to wait for a second between calls to GetRecords for a given shard + ticker := time.NewTicker(time.Second) + for { + select { + case <-ticker.C: + records, err := k.kClient.GetRecords(&kinesis.GetRecordsInput{ShardIterator: it}) + it = records.NextShardIterator + if err != nil { + switch err.(type) { + case *kinesis.ProvisionedThroughputExceededException: + logger.Warn("Provisioned throughput exceeded") + //TODO: implement exponential backoff + continue + case *kinesis.ExpiredIteratorException: + logger.Warn("Expired iterator") + continue + default: + logger.Error("Cannot get records") + return errors.Wrap(err, "Cannot get records") + } + } + k.ParseAndPushRecords(records.Records, out, logger, shardId) + + if it == nil { + logger.Warnf("Shard has been closed") + return nil + } + case <-k.shardReaderTomb.Dying(): + logger.Infof("shardReaderTomb is dying, exiting ReadFromShard") + ticker.Stop() + return nil + } + } +} + +func (k *KinesisSource) ReadFromStream(out chan types.Event, t *tomb.Tomb) error { + k.logger = k.logger.WithFields(log.Fields{"stream": k.Config.StreamName}) + k.logger.Info("starting kinesis acquisition from shards") + for { + shards, err := k.kClient.ListShards(&kinesis.ListShardsInput{ + StreamName: aws.String(k.Config.StreamName), + }) + if err != nil { + return errors.Wrap(err, "Cannot list shards") + } + k.shardReaderTomb = &tomb.Tomb{} + for _, shard := range shards.Shards { + shardId := *shard.ShardId + k.shardReaderTomb.Go(func() error { + defer types.CatchPanic("crowdsec/acquis/kinesis/streaming/shard") + return k.ReadFromShard(out, shardId) + }) + } + select { + case <-t.Dying(): + k.logger.Info("kinesis source is dying") + k.shardReaderTomb.Kill(nil) + _ = k.shardReaderTomb.Wait() //we don't care about the error as we kill the tomb ourselves + return nil + case <-k.shardReaderTomb.Dying(): + reason := k.shardReaderTomb.Err() + if reason != nil { + k.logger.Errorf("Unexpected error from shard reader : %s", reason) + return reason + } + k.logger.Infof("All shards have been closed, probably a resharding event, restarting acquisition") + continue + } + } +} + +func (k *KinesisSource) StreamingAcquisition(out chan types.Event, t *tomb.Tomb) error { + t.Go(func() error { + defer types.CatchPanic("crowdsec/acquis/kinesis/streaming") + if k.Config.UseEnhancedFanOut { + return k.EnhancedRead(out, t) + } else { + return k.ReadFromStream(out, t) + } + }) + return nil +} + +func (k *KinesisSource) CanRun() error { + return nil +} + +func (k *KinesisSource) Dump() interface{} { + return k +} diff --git a/pkg/acquisition/modules/kinesis/kinesis_test.go b/pkg/acquisition/modules/kinesis/kinesis_test.go new file mode 100644 index 000000000..c17e9a9ca --- /dev/null +++ b/pkg/acquisition/modules/kinesis/kinesis_test.go @@ -0,0 +1,325 @@ +package kinesisacquisition + +import ( + "bytes" + "compress/gzip" + "encoding/json" + "fmt" + "net" + "os" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/crowdsecurity/crowdsec/pkg/types" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "gopkg.in/tomb.v2" +) + +func getLocalStackEndpoint() (string, error) { + endpoint := "http://localhost:4566" + if v := os.Getenv("AWS_ENDPOINT_FORCE"); v != "" { + v = strings.TrimPrefix(v, "http://") + _, err := net.Dial("tcp", v) + if err != nil { + return "", fmt.Errorf("while dialing %s : %s : aws endpoint isn't available", v, err) + } + } + return endpoint, nil +} + +func GenSubObject(i int) []byte { + r := CloudWatchSubscriptionRecord{ + MessageType: "subscription", + Owner: "test", + LogGroup: "test", + LogStream: "test", + SubscriptionFilters: []string{"filter1"}, + LogEvents: []CloudwatchSubscriptionLogEvent{ + { + ID: "testid", + Message: fmt.Sprintf("%d", i), + Timestamp: time.Now().Unix(), + }, + }, + } + body, err := json.Marshal(r) + if err != nil { + log.Fatal(err) + } + var b bytes.Buffer + gz := gzip.NewWriter(&b) + gz.Write(body) + gz.Close() + //AWS actually base64 encodes the data, but it looks like kinesis automatically decodes it at some point + //localstack does not do it, so let's just write a raw gzipped stream + return b.Bytes() +} + +func WriteToStream(streamName string, count int, shards int, sub bool) { + endpoint, err := getLocalStackEndpoint() + if err != nil { + log.Fatal(err) + } + sess := session.Must(session.NewSession()) + kinesisClient := kinesis.New(sess, aws.NewConfig().WithEndpoint(endpoint).WithRegion("us-east-1")) + for i := 0; i < count; i++ { + partition := "partition" + if shards != 1 { + partition = fmt.Sprintf("partition-%d", i%shards) + } + var data []byte + if sub { + data = GenSubObject(i) + } else { + data = []byte(fmt.Sprintf("%d", i)) + } + _, err = kinesisClient.PutRecord(&kinesis.PutRecordInput{ + Data: data, + PartitionKey: aws.String(partition), + StreamName: aws.String(streamName), + }) + if err != nil { + fmt.Printf("Error writing to stream: %s\n", err) + log.Fatal(err) + } + } +} + +func TestMain(m *testing.M) { + os.Setenv("AWS_ACCESS_KEY_ID", "foobar") + os.Setenv("AWS_SECRET_ACCESS_KEY", "foobar") + + //delete_streams() + //create_streams() + code := m.Run() + //delete_streams() + os.Exit(code) +} + +func TestBadConfiguration(t *testing.T) { + tests := []struct { + config string + expectedErr string + }{ + { + config: `source: kinesis`, + expectedErr: "stream_name is mandatory when use_enhanced_fanout is false", + }, + { + config: ` +source: kinesis +use_enhanced_fanout: true`, + expectedErr: "stream_arn is mandatory when use_enhanced_fanout is true", + }, + { + config: ` +source: kinesis +use_enhanced_fanout: true +stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`, + expectedErr: "consumer_name is mandatory when use_enhanced_fanout is true", + }, + { + config: ` +source: kinesis +stream_name: foobar +stream_arn: arn:aws:kinesis:eu-west-1:123456789012:stream/my-stream`, + expectedErr: "stream_arn and stream_name are mutually exclusive", + }, + } + + subLogger := log.WithFields(log.Fields{ + "type": "kinesis", + }) + for _, test := range tests { + f := KinesisSource{} + err := f.Configure([]byte(test.config), subLogger) + if test.expectedErr != "" && err == nil { + t.Fatalf("Expected err %s but got nil !", test.expectedErr) + } + if test.expectedErr != "" { + assert.Contains(t, err.Error(), test.expectedErr) + } + } +} + +func TestReadFromStream(t *testing.T) { + tests := []struct { + config string + count int + shards int + }{ + { + config: `source: kinesis +aws_endpoint: %s +aws_region: us-east-1 +stream_name: stream-1-shard`, + count: 10, + shards: 1, + }, + } + endpoint, _ := getLocalStackEndpoint() + for _, test := range tests { + f := KinesisSource{} + config := fmt.Sprintf(test.config, endpoint) + err := f.Configure([]byte(config), log.WithFields(log.Fields{ + "type": "kinesis", + })) + if err != nil { + t.Fatalf("Error configuring source: %s", err) + } + tomb := &tomb.Tomb{} + out := make(chan types.Event) + err = f.StreamingAcquisition(out, tomb) + if err != nil { + t.Fatalf("Error starting source: %s", err) + } + //Allow the datasource to start listening to the stream + time.Sleep(4 * time.Second) + WriteToStream(f.Config.StreamName, test.count, test.shards, false) + for i := 0; i < test.count; i++ { + e := <-out + assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw) + } + tomb.Kill(nil) + tomb.Wait() + } +} + +func TestReadFromMultipleShards(t *testing.T) { + tests := []struct { + config string + count int + shards int + }{ + { + config: `source: kinesis +aws_endpoint: %s +aws_region: us-east-1 +stream_name: stream-2-shards`, + count: 10, + shards: 2, + }, + } + endpoint, _ := getLocalStackEndpoint() + for _, test := range tests { + f := KinesisSource{} + config := fmt.Sprintf(test.config, endpoint) + err := f.Configure([]byte(config), log.WithFields(log.Fields{ + "type": "kinesis", + })) + if err != nil { + t.Fatalf("Error configuring source: %s", err) + } + tomb := &tomb.Tomb{} + out := make(chan types.Event) + err = f.StreamingAcquisition(out, tomb) + if err != nil { + t.Fatalf("Error starting source: %s", err) + } + //Allow the datasource to start listening to the stream + time.Sleep(4 * time.Second) + WriteToStream(f.Config.StreamName, test.count, test.shards, false) + c := 0 + for i := 0; i < test.count; i++ { + <-out + c += 1 + } + assert.Equal(t, test.count, c) + tomb.Kill(nil) + tomb.Wait() + } +} + +func TestFromSubscription(t *testing.T) { + tests := []struct { + config string + count int + shards int + }{ + { + config: `source: kinesis +aws_endpoint: %s +aws_region: us-east-1 +stream_name: stream-1-shard +from_subscription: true`, + count: 10, + shards: 1, + }, + } + endpoint, _ := getLocalStackEndpoint() + for _, test := range tests { + f := KinesisSource{} + config := fmt.Sprintf(test.config, endpoint) + err := f.Configure([]byte(config), log.WithFields(log.Fields{ + "type": "kinesis", + })) + if err != nil { + t.Fatalf("Error configuring source: %s", err) + } + tomb := &tomb.Tomb{} + out := make(chan types.Event) + err = f.StreamingAcquisition(out, tomb) + if err != nil { + t.Fatalf("Error starting source: %s", err) + } + //Allow the datasource to start listening to the stream + time.Sleep(4 * time.Second) + WriteToStream(f.Config.StreamName, test.count, test.shards, true) + for i := 0; i < test.count; i++ { + e := <-out + assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw) + } + tomb.Kill(nil) + tomb.Wait() + } +} + +/* +func TestSubscribeToStream(t *testing.T) { + tests := []struct { + config string + count int + shards int + }{ + { + config: `source: kinesis +aws_endpoint: %s +aws_region: us-east-1 +stream_arn: arn:aws:kinesis:us-east-1:000000000000:stream/stream-1-shard +consumer_name: consumer-1 +use_enhanced_fanout: true`, + count: 10, + shards: 1, + }, + } + endpoint, _ := getLocalStackEndpoint() + for _, test := range tests { + f := KinesisSource{} + config := fmt.Sprintf(test.config, endpoint) + err := f.Configure([]byte(config), log.WithFields(log.Fields{ + "type": "kinesis", + })) + if err != nil { + t.Fatalf("Error configuring source: %s", err) + } + tomb := &tomb.Tomb{} + out := make(chan types.Event) + err = f.StreamingAcquisition(out, tomb) + if err != nil { + t.Fatalf("Error starting source: %s", err) + } + //Allow the datasource to start listening to the stream + time.Sleep(10 * time.Second) + WriteToStream("stream-1-shard", test.count, test.shards) + for i := 0; i < test.count; i++ { + e := <-out + assert.Equal(t, fmt.Sprintf("%d", i), e.Line.Raw) + } + } +} +*/