From 98f2ac5e7c1c9811579187cba0049b327fc89ea6 Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Wed, 18 May 2022 10:08:37 +0200 Subject: [PATCH] fix #1385: .yaml.local (#1497) Added support for .yaml.local files to override values in .yaml --- pkg/csconfig/api.go | 4 +- pkg/csconfig/config.go | 10 +- pkg/csconfig/config_test.go | 14 +- pkg/csconfig/profiles.go | 11 +- pkg/csconfig/simulation.go | 9 +- pkg/csconfig/simulation_test.go | 7 +- pkg/cstest/utils.go | 13 +- pkg/yamlpatch/merge.go | 168 +++++++++++++ pkg/yamlpatch/merge_test.go | 235 +++++++++++++++++++ pkg/yamlpatch/patcher.go | 153 ++++++++++++ pkg/yamlpatch/patcher_test.go | 313 +++++++++++++++++++++++++ pkg/yamlpatch/testdata/base.yaml | 13 + pkg/yamlpatch/testdata/expect.yaml | 13 + pkg/yamlpatch/testdata/production.yaml | 13 + tests/bats/05_config_yaml_local.bats | 134 +++++++++++ tests/bats/40_live-ban.bats | 1 - tests/lib/util/wait-for-port | 19 +- 17 files changed, 1088 insertions(+), 42 deletions(-) create mode 100644 pkg/yamlpatch/merge.go create mode 100644 pkg/yamlpatch/merge_test.go create mode 100644 pkg/yamlpatch/patcher.go create mode 100644 pkg/yamlpatch/patcher_test.go create mode 100644 pkg/yamlpatch/testdata/base.yaml create mode 100644 pkg/yamlpatch/testdata/expect.yaml create mode 100644 pkg/yamlpatch/testdata/production.yaml create mode 100644 tests/bats/05_config_yaml_local.bats diff --git a/pkg/csconfig/api.go b/pkg/csconfig/api.go index 1c7686093..b77b5a378 100644 --- a/pkg/csconfig/api.go +++ b/pkg/csconfig/api.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/crowdsecurity/crowdsec/pkg/apiclient" + "github.com/crowdsecurity/crowdsec/pkg/yamlpatch" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" @@ -54,7 +55,8 @@ func (o *OnlineApiClientCfg) Load() error { } func (l *LocalApiClientCfg) Load() error { - fcontent, err := ioutil.ReadFile(l.CredentialsFilePath) + patcher := yamlpatch.NewPatcher(l.CredentialsFilePath, ".local") + fcontent, err := patcher.MergedPatchContent() if err != nil { return errors.Wrapf(err, "failed to read api client credential configuration file '%s'", l.CredentialsFilePath) } diff --git a/pkg/csconfig/config.go b/pkg/csconfig/config.go index e0ea65843..d03a89b26 100644 --- a/pkg/csconfig/config.go +++ b/pkg/csconfig/config.go @@ -2,11 +2,11 @@ package csconfig import ( "fmt" - "io/ioutil" "os" "path/filepath" "github.com/crowdsecurity/crowdsec/pkg/types" + "github.com/crowdsecurity/crowdsec/pkg/yamlpatch" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" @@ -46,9 +46,10 @@ func (c *Config) Dump() error { } func NewConfig(configFile string, disableAgent bool, disableAPI bool) (*Config, error) { - fcontent, err := ioutil.ReadFile(configFile) + patcher := yamlpatch.NewPatcher(configFile, ".local") + fcontent, err := patcher.MergedPatchContent() if err != nil { - return nil, errors.Wrap(err, "failed to read config file") + return nil, err } configData := os.ExpandEnv(string(fcontent)) cfg := Config{ @@ -59,7 +60,8 @@ func NewConfig(configFile string, disableAgent bool, disableAPI bool) (*Config, err = yaml.UnmarshalStrict([]byte(configData), &cfg) if err != nil { - return nil, err + // this is actually the "merged" yaml + return nil, errors.Wrap(err, configFile) } return &cfg, nil } diff --git a/pkg/csconfig/config_test.go b/pkg/csconfig/config_test.go index c0e2b7f19..cfa16726d 100644 --- a/pkg/csconfig/config_test.go +++ b/pkg/csconfig/config_test.go @@ -4,7 +4,6 @@ import ( "fmt" "log" "runtime" - "strings" "testing" "github.com/stretchr/testify/assert" @@ -19,20 +18,13 @@ func TestNormalLoad(t *testing.T) { _, err = NewConfig("./tests/xxx.yaml", false, false) if runtime.GOOS != "windows" { - if fmt.Sprintf("%s", err) != "failed to read config file: open ./tests/xxx.yaml: no such file or directory" { - t.Fatalf("unexpected error %s", err) - } + assert.EqualError(t, err, "while reading ./tests/xxx.yaml: open ./tests/xxx.yaml: no such file or directory") } else { - if fmt.Sprintf("%s", err) != "failed to read config file: open ./tests/xxx.yaml: The system cannot find the file specified." { - t.Fatalf("unexpected error %s", err) - } + assert.EqualError(t, err, "while reading ./tests/xxx.yaml: open ./tests/xxx.yaml: The system cannot find the file specified.") } _, err = NewConfig("./tests/simulation.yaml", false, false) - if !strings.HasPrefix(fmt.Sprintf("%s", err), "yaml: unmarshal errors:") { - t.Fatalf("unexpected error %s", err) - } - + assert.EqualError(t, err, "./tests/simulation.yaml: yaml: unmarshal errors:\n line 1: field simulation not found in type csconfig.Config") } func TestNewCrowdSecConfig(t *testing.T) { diff --git a/pkg/csconfig/profiles.go b/pkg/csconfig/profiles.go index 23757d171..e2a1bbb37 100644 --- a/pkg/csconfig/profiles.go +++ b/pkg/csconfig/profiles.go @@ -1,15 +1,16 @@ package csconfig import ( + "bytes" "fmt" "io" - "os" "time" "github.com/antonmedv/expr" "github.com/antonmedv/expr/vm" "github.com/crowdsecurity/crowdsec/pkg/exprhelpers" "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/yamlpatch" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" @@ -33,13 +34,15 @@ func (c *LocalApiServerCfg) LoadProfiles() error { return fmt.Errorf("empty profiles path") } - yamlFile, err := os.Open(c.ProfilesPath) + patcher := yamlpatch.NewPatcher(c.ProfilesPath, ".local") + fcontent, err := patcher.PrependedPatchContent() if err != nil { - return errors.Wrapf(err, "while opening %s", c.ProfilesPath) + return err } + reader := bytes.NewReader(fcontent) //process the yaml - dec := yaml.NewDecoder(yamlFile) + dec := yaml.NewDecoder(reader) dec.SetStrict(true) for { t := ProfileCfg{} diff --git a/pkg/csconfig/simulation.go b/pkg/csconfig/simulation.go index 4325045c8..69c520c5c 100644 --- a/pkg/csconfig/simulation.go +++ b/pkg/csconfig/simulation.go @@ -2,10 +2,9 @@ package csconfig import ( "fmt" - "io/ioutil" "path/filepath" - "github.com/pkg/errors" + "github.com/crowdsecurity/crowdsec/pkg/yamlpatch" "gopkg.in/yaml.v2" ) @@ -39,9 +38,11 @@ func (c *Config) LoadSimulation() error { if c.ConfigPaths.SimulationFilePath == "" { c.ConfigPaths.SimulationFilePath = filepath.Clean(c.ConfigPaths.ConfigDir + "/simulation.yaml") } - rcfg, err := ioutil.ReadFile(c.ConfigPaths.SimulationFilePath) + + patcher := yamlpatch.NewPatcher(c.ConfigPaths.SimulationFilePath, ".local") + rcfg, err := patcher.MergedPatchContent() if err != nil { - return errors.Wrapf(err, "while reading '%s'", c.ConfigPaths.SimulationFilePath) + return err } if err := yaml.UnmarshalStrict(rcfg, &simCfg); err != nil { return fmt.Errorf("while unmarshaling simulation file '%s' : %s", c.ConfigPaths.SimulationFilePath, err) diff --git a/pkg/csconfig/simulation_test.go b/pkg/csconfig/simulation_test.go index b02011f63..080d3212a 100644 --- a/pkg/csconfig/simulation_test.go +++ b/pkg/csconfig/simulation_test.go @@ -89,7 +89,7 @@ func TestSimulationLoading(t *testing.T) { }, Crowdsec: &CrowdsecServiceCfg{}, }, - err: fmt.Sprintf("while reading '%s': open %s: The system cannot find the file specified.", testXXFullPath, testXXFullPath), + err: fmt.Sprintf("while reading %s: open %s: The system cannot find the file specified.", testXXFullPath, testXXFullPath), }) } else { tests = append(tests, struct { @@ -106,7 +106,7 @@ func TestSimulationLoading(t *testing.T) { }, Crowdsec: &CrowdsecServiceCfg{}, }, - err: fmt.Sprintf("while reading '%s': open %s: no such file or directory", testXXFullPath, testXXFullPath), + err: fmt.Sprintf("while reading %s: open %s: no such file or directory", testXXFullPath, testXXFullPath), }) } @@ -115,7 +115,8 @@ func TestSimulationLoading(t *testing.T) { if err == nil && test.err != "" { fmt.Printf("TEST '%s': NOK\n", test.name) t.Fatalf("%d/%d expected error, didn't get it", idx, len(tests)) - } else if test.err != "" { + } + if test.err != "" { if !strings.HasPrefix(fmt.Sprintf("%s", err), test.err) { fmt.Printf("TEST '%s': NOK\n", test.name) t.Fatalf("%d/%d expected '%s' got '%s'", idx, len(tests), diff --git a/pkg/cstest/utils.go b/pkg/cstest/utils.go index ea44820d9..00be61ceb 100644 --- a/pkg/cstest/utils.go +++ b/pkg/cstest/utils.go @@ -111,15 +111,10 @@ func CopyDir(src string, dest string) error { } func AssertErrorContains(t *testing.T, err error, expectedErr string) { - if expectedErr == "" { - if err != nil { - t.Fatalf("Unexpected error: %s", err) - } - assert.Equal(t, err, nil) + if expectedErr != "" { + assert.ErrorContains(t, err, expectedErr) return } - if err == nil { - t.Fatalf("Expected '%s', got nil", expectedErr) - } - assert.Contains(t, err.Error(), expectedErr) + + assert.NoError(t, err) } diff --git a/pkg/yamlpatch/merge.go b/pkg/yamlpatch/merge.go new file mode 100644 index 000000000..8a61b6470 --- /dev/null +++ b/pkg/yamlpatch/merge.go @@ -0,0 +1,168 @@ +// +// from https://github.com/uber-go/config/tree/master/internal/merge +// +// Copyright (c) 2019 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package yamlpatch + +import ( + "bytes" + "fmt" + "io" + + "github.com/pkg/errors" + + yaml "gopkg.in/yaml.v2" +) + +type ( + // YAML has three fundamental types. When unmarshaled into interface{}, + // they're represented like this. + mapping = map[interface{}]interface{} + sequence = []interface{} +) + +// YAML deep-merges any number of YAML sources, with later sources taking +// priority over earlier ones. +// +// Maps are deep-merged. For example, +// {"one": 1, "two": 2} + {"one": 42, "three": 3} +// == {"one": 42, "two": 2, "three": 3} +// Sequences are replaced. For example, +// {"foo": [1, 2, 3]} + {"foo": [4, 5, 6]} +// == {"foo": [4, 5, 6]} +// +// In non-strict mode, duplicate map keys are allowed within a single source, +// with later values overwriting previous ones. Attempting to merge +// mismatched types (e.g., merging a sequence into a map) replaces the old +// value with the new. +// +// Enabling strict mode returns errors in both of the above cases. +func YAML(sources [][]byte, strict bool) (*bytes.Buffer, error) { + var merged interface{} + var hasContent bool + for _, r := range sources { + d := yaml.NewDecoder(bytes.NewReader(r)) + d.SetStrict(strict) + + var contents interface{} + if err := d.Decode(&contents); err == io.EOF { + // Skip empty and comment-only sources, which we should handle + // differently from explicit nils. + continue + } else if err != nil { + return nil, fmt.Errorf("couldn't decode source: %v", err) + } + + hasContent = true + pair, err := merge(merged, contents, strict) + if err != nil { + return nil, err // error is already descriptive enough + } + merged = pair + } + + buf := &bytes.Buffer{} + if !hasContent { + // No sources had any content. To distinguish this from a source with just + // an explicit top-level null, return an empty buffer. + return buf, nil + } + enc := yaml.NewEncoder(buf) + if err := enc.Encode(merged); err != nil { + return nil, errors.Wrap(err, "couldn't re-serialize merged YAML") + } + return buf, nil +} + +func merge(into, from interface{}, strict bool) (interface{}, error) { + // It's possible to handle this with a mass of reflection, but we only need + // to merge whole YAML files. Since we're always unmarshaling into + // interface{}, we only need to handle a few types. This ends up being + // cleaner if we just handle each case explicitly. + if into == nil { + return from, nil + } + if from == nil { + // Allow higher-priority YAML to explicitly nil out lower-priority entries. + return nil, nil + } + if IsScalar(into) && IsScalar(from) { + return from, nil + } + if IsSequence(into) && IsSequence(from) { + return from, nil + } + if IsMapping(into) && IsMapping(from) { + return mergeMapping(into.(mapping), from.(mapping), strict) + } + // YAML types don't match, so no merge is possible. For backward + // compatibility, ignore mismatches unless we're in strict mode and return + // the higher-priority value. + if !strict { + return from, nil + } + return nil, fmt.Errorf("can't merge a %s into a %s", describe(from), describe(into)) +} + +func mergeMapping(into, from mapping, strict bool) (mapping, error) { + merged := make(mapping, len(into)) + for k, v := range into { + merged[k] = v + } + for k := range from { + m, err := merge(merged[k], from[k], strict) + if err != nil { + return nil, err + } + merged[k] = m + } + return merged, nil +} + +// IsMapping reports whether a type is a mapping in YAML, represented as a +// map[interface{}]interface{}. +func IsMapping(i interface{}) bool { + _, is := i.(mapping) + return is +} + +// IsSequence reports whether a type is a sequence in YAML, represented as an +// []interface{}. +func IsSequence(i interface{}) bool { + _, is := i.(sequence) + return is +} + +// IsScalar reports whether a type is a scalar value in YAML. +func IsScalar(i interface{}) bool { + return !IsMapping(i) && !IsSequence(i) +} + +func describe(i interface{}) string { + if IsMapping(i) { + return "mapping" + } + if IsSequence(i) { + return "sequence" + } + return "scalar" +} diff --git a/pkg/yamlpatch/merge_test.go b/pkg/yamlpatch/merge_test.go new file mode 100644 index 000000000..40a67934d --- /dev/null +++ b/pkg/yamlpatch/merge_test.go @@ -0,0 +1,235 @@ +// Copyright (c) 2018 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package yamlpatch + +import ( + "bytes" + "io/ioutil" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + yaml "gopkg.in/yaml.v2" +) + +func trimcr(s string) string { + return strings.ReplaceAll(s, "\r\n", "\n") +} + +func mustRead(t testing.TB, fname string) []byte { + contents, err := ioutil.ReadFile(fname) + require.NoError(t, err, "failed to read file: %s", fname) + return contents +} + +func dump(t testing.TB, actual, expected string) { + // It's impossible to debug YAML if the actual and expected values are + // printed on a single line. + t.Logf("Actual:\n\n%s\n\n", actual) + t.Logf("Expected:\n\n%s\n\n", expected) +} + +func strip(s string) string { + // It's difficult to write string constants that are valid YAML. Normalize + // strings for ease of testing. + s = strings.TrimSpace(s) + s = strings.Replace(s, "\t", " ", -1) + return s +} + +func canonicalize(t testing.TB, s string) string { + // round-trip to canonicalize formatting + var i interface{} + require.NoError(t, + yaml.Unmarshal([]byte(strip(s)), &i), + "canonicalize: couldn't unmarshal YAML", + ) + formatted, err := yaml.Marshal(i) + require.NoError(t, err, "canonicalize: couldn't marshal YAML") + return string(bytes.TrimSpace(formatted)) +} + +func unmarshal(t testing.TB, s string) interface{} { + var i interface{} + require.NoError(t, yaml.Unmarshal([]byte(strip(s)), &i), "unmarshaling failed") + return i +} + +func succeeds(t testing.TB, strict bool, left, right, expect string) { + l, r := unmarshal(t, left), unmarshal(t, right) + m, err := merge(l, r, strict) + require.NoError(t, err, "merge failed") + + actualBytes, err := yaml.Marshal(m) + require.NoError(t, err, "couldn't marshal merged structure") + actual := canonicalize(t, string(actualBytes)) + expect = canonicalize(t, expect) + if !assert.Equal(t, expect, actual) { + dump(t, actual, expect) + } +} + +func fails(t testing.TB, strict bool, left, right string) { + _, err := merge(unmarshal(t, left), unmarshal(t, right), strict) + assert.Error(t, err, "merge succeeded") +} + +func TestIntegration(t *testing.T) { + base := mustRead(t, "testdata/base.yaml") + prod := mustRead(t, "testdata/production.yaml") + expect := mustRead(t, "testdata/expect.yaml") + + merged, err := YAML([][]byte{base, prod}, true /* strict */) + require.NoError(t, err, "merge failed") + + if !assert.Equal(t, trimcr(string(expect)), merged.String(), "unexpected contents") { + dump(t, merged.String(), string(expect)) + } +} + +func TestEmpty(t *testing.T) { + full := []byte("foo: bar\n") + null := []byte("~") + + tests := []struct { + desc string + sources [][]byte + expect string + }{ + {"empty base", [][]byte{nil, full}, string(full)}, + {"empty override", [][]byte{full, nil}, string(full)}, + {"both empty", [][]byte{nil, nil}, ""}, + {"null base", [][]byte{null, full}, string(full)}, + {"null override", [][]byte{full, null}, "null\n"}, + {"empty base and null override", [][]byte{nil, null}, "null\n"}, + {"null base and empty override", [][]byte{null, nil}, "null\n"}, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + merged, err := YAML(tt.sources, true /* strict */) + require.NoError(t, err, "merge failed") + assert.Equal(t, tt.expect, merged.String(), "wrong contents after merge") + }) + } +} + +func TestSuccess(t *testing.T) { + left := ` +fun: [maserati, porsche] +practical: {toyota: camry, honda: accord} +occupants: + honda: {driver: jane, backseat: [nate]} + ` + right := ` +fun: [lamborghini, porsche] +practical: {honda: civic, nissan: altima} +occupants: + honda: {passenger: arthur, backseat: [nora]} + ` + expect := ` +fun: [lamborghini, porsche] +practical: {toyota: camry, honda: civic, nissan: altima} +occupants: + honda: {passenger: arthur, driver: jane, backseat: [nora]} + ` + succeeds(t, true, left, right, expect) + succeeds(t, false, left, right, expect) +} + +func TestErrors(t *testing.T) { + check := func(t testing.TB, strict bool, sources ...[]byte) error { + _, err := YAML(sources, strict) + return err + } + t.Run("tabs in source", func(t *testing.T) { + src := []byte("foo:\n\tbar:baz") + assert.Error(t, check(t, false, src), "expected error in permissive mode") + assert.Error(t, check(t, true, src), "expected error in strict mode") + }) + + t.Run("duplicated keys", func(t *testing.T) { + src := []byte("{foo: bar, foo: baz}") + assert.NoError(t, check(t, false, src), "expected success in permissive mode") + assert.Error(t, check(t, true, src), "expected error in permissive mode") + }) + + t.Run("merge error", func(t *testing.T) { + left := []byte("foo: [1, 2]") + right := []byte("foo: {bar: baz}") + assert.NoError(t, check(t, false, left, right), "expected success in permissive mode") + assert.Error(t, check(t, true, left, right), "expected error in strict mode") + }) +} + +func TestMismatchedTypes(t *testing.T) { + tests := []struct { + desc string + left, right string + }{ + {"sequence and mapping", "[one, two]", "{foo: bar}"}, + {"sequence and scalar", "[one, two]", "foo"}, + {"mapping and scalar", "{foo: bar}", "foo"}, + {"nested", "{foo: [one, two]}", "{foo: bar}"}, + } + + for _, tt := range tests { + t.Run(tt.desc+" strict", func(t *testing.T) { + fails(t, true, tt.left, tt.right) + }) + t.Run(tt.desc+" permissive", func(t *testing.T) { + // prefer the higher-priority value + succeeds(t, false, tt.left, tt.right, tt.right) + }) + } +} + +func TestBooleans(t *testing.T) { + // YAML helpfully interprets many strings as Booleans. + tests := []struct { + in, out string + }{ + {"yes", "true"}, + {"YES", "true"}, + {"on", "true"}, + {"ON", "true"}, + {"no", "false"}, + {"NO", "false"}, + {"off", "false"}, + {"OFF", "false"}, + } + + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + succeeds(t, true, "", tt.in, tt.out) + succeeds(t, false, "", tt.in, tt.out) + }) + } +} + +func TestExplicitNil(t *testing.T) { + base := `foo: {one: two}` + override := `foo: ~` + expect := `foo: ~` + succeeds(t, true, base, override, expect) + succeeds(t, false, base, override, expect) +} diff --git a/pkg/yamlpatch/patcher.go b/pkg/yamlpatch/patcher.go new file mode 100644 index 000000000..2da696c91 --- /dev/null +++ b/pkg/yamlpatch/patcher.go @@ -0,0 +1,153 @@ +package yamlpatch + +import ( + "bytes" + "io" + "os" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "gopkg.in/yaml.v2" +) + +type Patcher struct { + BaseFilePath string + PatchFilePath string +} + +func NewPatcher(filePath string, suffix string) *Patcher { + return &Patcher{ + BaseFilePath: filePath, + PatchFilePath: filePath + suffix, + } +} + +// read a single YAML file, check for errors (the merge package doesn't) then return the content as bytes. +func readYAML(filePath string) ([]byte, error) { + var content []byte + + var err error + + if content, err = os.ReadFile(filePath); err != nil { + return nil, errors.Wrapf(err, "while reading %s", filePath) + } + + var yamlMap map[interface{}]interface{} + if err = yaml.Unmarshal(content, &yamlMap); err != nil { + return nil, errors.Wrap(err, filePath) + } + + return content, nil +} + +// MergedPatchContent reads a YAML file and, if it exists, its patch file, +// then merges them and returns it serialized. +func (p *Patcher) MergedPatchContent() ([]byte, error) { + var err error + + var base []byte + + base, err = readYAML(p.BaseFilePath) + if err != nil { + return nil, err + } + + var over []byte + + over, err = readYAML(p.PatchFilePath) + // optional file, ignore if it does not exist + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, err + } + if err == nil { + log.Debugf("Patching yaml: '%s' with '%s'", p.BaseFilePath, p.PatchFilePath) + } + + var patched *bytes.Buffer + + // strict mode true, will raise errors for duplicate map keys and + // overriding with a different type + patched, err = YAML([][]byte{base, over}, true) + if err != nil { + return nil, err + } + + return patched.Bytes(), nil +} + +// read multiple YAML documents inside a file, and writes them to a buffer +// separated by the appropriate '---' terminators. +func decodeDocuments(file *os.File, buf *bytes.Buffer, finalDashes bool) error { + var ( + err error + docBytes []byte + ) + + dec := yaml.NewDecoder(file) + dec.SetStrict(true) + + dashTerminator := false + + for { + yml := make(map[interface{}]interface{}) + + err = dec.Decode(&yml) + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return errors.Wrapf(err, "while decoding %s", file.Name()) + } + + docBytes, err = yaml.Marshal(&yml) + if err != nil { + return errors.Wrapf(err, "while marshaling %s", file.Name()) + } + + if dashTerminator { + buf.Write([]byte("---\n")) + } + + buf.Write(docBytes) + dashTerminator = true + } + if dashTerminator && finalDashes { + buf.Write([]byte("---\n")) + } + return nil +} + +// PrependedPatchContent collates the base .yaml file with the .yaml.patch, by putting +// the content of the patch BEFORE the base document. The result is a multi-document +// YAML in all cases, even if the base and patch files are single documents. +func (p *Patcher) PrependedPatchContent() ([]byte, error) { + var ( + result bytes.Buffer + patchFile *os.File + baseFile *os.File + err error + ) + + patchFile, err = os.Open(p.PatchFilePath) + // optional file, ignore if it does not exist + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, errors.Wrapf(err, "while opening %s", p.PatchFilePath) + } + + if patchFile != nil { + if err = decodeDocuments(patchFile, &result, true); err != nil { + return nil, err + } + } + + baseFile, err = os.Open(p.BaseFilePath) + if err != nil { + return nil, errors.Wrapf(err, "while opening %s", p.BaseFilePath) + } + + if err = decodeDocuments(baseFile, &result, false); err != nil { + return nil, err + } + + return result.Bytes(), nil +} diff --git a/pkg/yamlpatch/patcher_test.go b/pkg/yamlpatch/patcher_test.go new file mode 100644 index 000000000..be4a855cf --- /dev/null +++ b/pkg/yamlpatch/patcher_test.go @@ -0,0 +1,313 @@ +package yamlpatch_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/crowdsecurity/crowdsec/pkg/yamlpatch" + "github.com/stretchr/testify/require" +) + +// similar to the one in cstest, but with test number too. We cannot import +// cstest here because of circular dependency. +func requireErrorContains(t *testing.T, err error, expectedErr string) { + t.Helper() + + if expectedErr != "" { + require.ErrorContains(t, err, expectedErr) + + return + } + + require.NoError(t, err) +} + +func TestMergedPatchContent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base string + patch string + expected string + expectedErr string + }{ + { + "invalid yaml in base", + "notayaml", + "", + "", + "config.yaml: yaml: unmarshal errors:", + }, + { + "invalid yaml in base (detailed message)", + "notayaml", + "", + "", + "cannot unmarshal !!str `notayaml`", + }, + { + "invalid yaml in patch", + "", + "notayaml", + "", + "config.yaml.local: yaml: unmarshal errors:", + }, + { + "invalid yaml in patch (detailed message)", + "", + "notayaml", + "", + "cannot unmarshal !!str `notayaml`", + }, + { + "basic merge", + "{'first':{'one':1,'two':2},'second':{'three':3}}", + "{'first':{'one':10,'dos':2}}", + "{'first':{'one':10,'dos':2,'two':2},'second':{'three':3}}", + "", + }, + + // bools and zero values; here the "mergo" package had issues + // so we used something simpler. + + { + "bool merge - off if false", + "bool: on", + "bool: off", + "bool: false", + "", + }, + { + "bool merge - on is true", + "bool: off", + "bool: on", + "bool: true", + "", + }, + { + "string is not a bool - on to off", + "{'bool': 'on'}", + "{'bool': 'off'}", + "{'bool': 'off'}", + "", + }, + { + "string is not a bool - off to on", + "{'bool': 'off'}", + "{'bool': 'on'}", + "{'bool': 'on'}", + "", + }, + { + "bool merge - true to false", + "{'bool': true}", + "{'bool': false}", + "{'bool': false}", + "", + }, + { + "bool merge - false to true", + "{'bool': false}", + "{'bool': true}", + "{'bool': true}", + "", + }, + { + "string merge - value to value", + "{'string': 'value'}", + "{'string': ''}", + "{'string': ''}", + "", + }, + { + "sequence merge - value to empty", + "{'sequence': [1, 2]}", + "{'sequence': []}", + "{'sequence': []}", + "", + }, + { + "map merge - value to value", + "{'map': {'one': 1, 'two': 2}}", + "{'map': {}}", + "{'map': {'one': 1, 'two': 2}}", + "", + }, + + // mismatched types + + { + "can't merge a sequence into a mapping", + "map: {'key': 'value'}", + "map: ['value1', 'value2']", + "", + "can't merge a sequence into a mapping", + }, + { + "can't merge a scalar into a mapping", + "map: {'key': 'value'}", + "map: 3", + "", + "can't merge a scalar into a mapping", + }, + { + "can't merge a mapping into a sequence", + "sequence: ['value1', 'value2']", + "sequence: {'key': 'value'}", + "", + "can't merge a mapping into a sequence", + }, + { + "can't merge a scalar into a sequence", + "sequence: ['value1', 'value2']", + "sequence: 3", + "", + "can't merge a scalar into a sequence", + }, + { + "can't merge a sequence into a scalar", + "scalar: true", + "scalar: ['value1', 'value2']", + "", + "can't merge a sequence into a scalar", + }, + { + "can't merge a mapping into a scalar", + "scalar: true", + "scalar: {'key': 'value'}", + "", + "can't merge a mapping into a scalar", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + dirPath, err := os.MkdirTemp("", "yamlpatch") + require.NoError(t, err) + + defer os.RemoveAll(dirPath) + + configPath := filepath.Join(dirPath, "config.yaml") + patchPath := filepath.Join(dirPath, "config.yaml.local") + err = os.WriteFile(configPath, []byte(tc.base), 0o600) + require.NoError(t, err) + + err = os.WriteFile(patchPath, []byte(tc.patch), 0o600) + require.NoError(t, err) + + patcher := yamlpatch.NewPatcher(configPath, ".local") + patchedBytes, err := patcher.MergedPatchContent() + requireErrorContains(t, err, tc.expectedErr) + require.YAMLEq(t, tc.expected, string(patchedBytes)) + }) + } +} + +func TestPrependedPatchContent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + base string + patch string + expected string + expectedErr string + }{ + // we test with scalars here, because YAMLeq does not work + // with multi-document files, so we need char-to-char comparison + // which is noisy with sequences and (unordered) mappings + { + "newlines are always appended, if missing, by yaml.Marshal()", + "foo: bar", + "", + "foo: bar\n", + "", + }, + { + "prepend empty document", + "foo: bar\n", + "", + "foo: bar\n", + "", + }, + { + "prepend a document to another", + "foo: bar", + "baz: qux", + "baz: qux\n---\nfoo: bar\n", + "", + }, + { + "prepend document with same key", + "foo: true", + "foo: false", + "foo: false\n---\nfoo: true\n", + "", + }, + { + "prepend multiple documents", + "one: 1\n---\ntwo: 2\n---\none: 3", + "four: 4\n---\none: 1.1", + "four: 4\n---\none: 1.1\n---\none: 1\n---\ntwo: 2\n---\none: 3\n", + "", + }, + { + "invalid yaml in base", + "blablabla", + "", + "", + "config.yaml: yaml: unmarshal errors:", + }, + { + "invalid yaml in base (detailed message)", + "blablabla", + "", + "", + "cannot unmarshal !!str `blablabla`", + }, + { + "invalid yaml in patch", + "", + "blablabla", + "", + "config.yaml.local: yaml: unmarshal errors:", + }, + { + "invalid yaml in patch (detailed message)", + "", + "blablabla", + "", + "cannot unmarshal !!str `blablabla`", + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + dirPath, err := os.MkdirTemp("", "yamlpatch") + require.NoError(t, err) + + defer os.RemoveAll(dirPath) + + configPath := filepath.Join(dirPath, "config.yaml") + patchPath := filepath.Join(dirPath, "config.yaml.local") + + err = os.WriteFile(configPath, []byte(tc.base), 0o600) + require.NoError(t, err) + + err = os.WriteFile(patchPath, []byte(tc.patch), 0o600) + require.NoError(t, err) + + patcher := yamlpatch.NewPatcher(configPath, ".local") + patchedBytes, err := patcher.PrependedPatchContent() + requireErrorContains(t, err, tc.expectedErr) + // YAMLeq does not handle multiple documents + require.Equal(t, tc.expected, string(patchedBytes)) + }) + } +} diff --git a/pkg/yamlpatch/testdata/base.yaml b/pkg/yamlpatch/testdata/base.yaml new file mode 100644 index 000000000..4ac551ad5 --- /dev/null +++ b/pkg/yamlpatch/testdata/base.yaml @@ -0,0 +1,13 @@ +fun: + - maserati + - porsche + +practical: + toyota: camry + honda: accord + +occupants: + honda: + driver: jane + backseat: + - nate diff --git a/pkg/yamlpatch/testdata/expect.yaml b/pkg/yamlpatch/testdata/expect.yaml new file mode 100644 index 000000000..c19091563 --- /dev/null +++ b/pkg/yamlpatch/testdata/expect.yaml @@ -0,0 +1,13 @@ +fun: +- lamborghini +- porsche +occupants: + honda: + backseat: + - nora + driver: jane + passenger: arthur +practical: + honda: civic + nissan: altima + toyota: camry diff --git a/pkg/yamlpatch/testdata/production.yaml b/pkg/yamlpatch/testdata/production.yaml new file mode 100644 index 000000000..7dab2aeee --- /dev/null +++ b/pkg/yamlpatch/testdata/production.yaml @@ -0,0 +1,13 @@ +fun: + - lamborghini + - porsche + +practical: + honda: civic + nissan: altima + +occupants: + honda: + passenger: arthur + backseat: + - nora diff --git a/tests/bats/05_config_yaml_local.bats b/tests/bats/05_config_yaml_local.bats new file mode 100644 index 000000000..1bdb95adb --- /dev/null +++ b/tests/bats/05_config_yaml_local.bats @@ -0,0 +1,134 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +fake_log() { + for _ in $(seq 1 6); do + echo "$(LC_ALL=C date '+%b %d %H:%M:%S ')"'sd-126005 sshd[12422]: Invalid user netflix from 1.1.1.172 port 35424' + done +} + +setup_file() { + load "../lib/setup_file.sh" +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + ./instance-data load +} + +teardown() { + ./instance-crowdsec stop +} + + +#---------- + +@test "${FILE} config.yaml.local - cscli (log_level)" { + yq e '.common.log_level="warning"' -i "${CONFIG_YAML}" + run -0 cscli config show --key Config.Common.LogLevel + assert_output "warning" + + echo "{'common':{'log_level':'debug'}}" > "${CONFIG_YAML}.local" + run -0 cscli config show --key Config.Common.LogLevel + assert_output "debug" +} + +@test "${FILE} config.yaml.local - cscli (log_level - with envvar)" { + yq e '.common.log_level="warning"' -i "${CONFIG_YAML}" + run -0 cscli config show --key Config.Common.LogLevel + assert_output "warning" + + export CROWDSEC_LOG_LEVEL=debug + echo "{'common':{'log_level':'${CROWDSEC_LOG_LEVEL}'}}" > "${CONFIG_YAML}.local" + run -0 cscli config show --key Config.Common.LogLevel + assert_output "debug" +} + +@test "${FILE} config.yaml.local - crowdsec (listen_url)" { + run -0 ./instance-crowdsec start + run -0 ./lib/util/wait-for-port -q 8080 + run -0 ./instance-crowdsec stop + + echo "{'api':{'server':{'listen_uri':127.0.0.1:8083}}}" > "${CONFIG_YAML}.local" + run -0 ./instance-crowdsec start + run -0 ./lib/util/wait-for-port -q 8083 + run -1 ./lib/util/wait-for-port -q 8080 + run -0 ./instance-crowdsec stop + + rm -f "${CONFIG_YAML}.local" + run -0 ./instance-crowdsec start + run -1 ./lib/util/wait-for-port -q 8083 + run -0 ./lib/util/wait-for-port -q 8080 +} + +@test "${FILE} local_api_credentials.yaml.local" { + echo "{'api':{'server':{'listen_uri':127.0.0.1:8083}}}" > "${CONFIG_YAML}.local" + run -0 ./instance-crowdsec start + run -0 nc -z localhost 8083 + + run -0 yq e '.api.client.credentials_path' < "${CONFIG_YAML}" + LOCAL_API_CREDENTIALS="${output}" + + run -1 cscli decisions list + echo "{'url':'http://127.0.0.1:8083'}" > "${LOCAL_API_CREDENTIALS}.local" + run -0 cscli decisions list +} + +@test "${FILE} simulation.yaml.local" { + run -0 yq e '.config_paths.simulation_path' < "${CONFIG_YAML}" + refute_output null + SIMULATION="${output}" + + echo "simulation: off" > "${SIMULATION}" + run -0 cscli simulation status -o human + assert_output --partial "global simulation: disabled" + + echo "simulation: on" > "${SIMULATION}" + run -0 cscli simulation status -o human + assert_output --partial "global simulation: enabled" + + echo "simulation: off" > "${SIMULATION}.local" + run -0 cscli simulation status -o human + assert_output --partial "global simulation: disabled" + + rm -f "${SIMULATION}.local" + run -0 cscli simulation status -o human + assert_output --partial "global simulation: enabled" +} + + +@test "${FILE} profiles.yaml.local" { + run -0 yq e '.api.server.profiles_path' < "${CONFIG_YAML}" + refute_output null + PROFILES="${output}" + + cat <<-EOT > "${PROFILES}.local" + name: default_ip_remediation + filters: + - Alert.Remediation == true && Alert.GetScope() == "Ip" + decisions: + - type: captcha + duration: 2h + on_success: break + EOT + + tmpfile=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp) + touch "${tmpfile}" + ACQUIS_YAML=$(config_yq '.crowdsec_service.acquisition_path') + echo -e "---\nfilename: ${tmpfile}\nlabels:\n type: syslog\n" >>"${ACQUIS_YAML}" + + ./instance-crowdsec start + sleep 1 + fake_log >>"${tmpfile}" + sleep 1 + rm -f -- "${tmpfile}" + run -0 cscli decisions list -o json + run -0 jq -c '.[].decisions[0] | [.value,.type]' <(output) + assert_output '["1.1.1.172","captcha"]' +} diff --git a/tests/bats/40_live-ban.bats b/tests/bats/40_live-ban.bats index 0cb38b46c..d78f43869 100644 --- a/tests/bats/40_live-ban.bats +++ b/tests/bats/40_live-ban.bats @@ -30,7 +30,6 @@ teardown() { #---------- @test "$FILE 1.1.1.172 has been banned" { - skip tmpfile=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp) touch "${tmpfile}" ACQUIS_YAML=$(config_yq '.crowdsec_service.acquisition_path') diff --git a/tests/lib/util/wait-for-port b/tests/lib/util/wait-for-port index 47102edaa..f08d59930 100755 --- a/tests/lib/util/wait-for-port +++ b/tests/lib/util/wait-for-port @@ -1,6 +1,7 @@ #!/usr/bin/env bash set -eu + script_name=$0 die() { @@ -9,23 +10,31 @@ die() { } about() { - die "usage: $script_name " + die "usage: ${script_name} [-q] " } [ $# -lt 1 ] && about +QUIET= +if [[ "$1" == "-q" ]]; then + QUIET=quiet + shift +fi + +[ $# -lt 1 ] && about + port_number=$1 for _ in $(seq 40); do - nc -z localhost "$port_number" >/dev/null 2>&1 && exit 0 - sleep .05 + nc -z localhost "${port_number}" >/dev/null 2>&1 && exit 0 + sleep .03 done # send to &3 if open if { true >&3; } 2>/dev/null; then - echo "Can't connect to port $port_number" >&3 + [[ -z "${QUIET}" ]] && echo "Can't connect to port ${port_number}" >&3 else - echo "Can't connect to port $port_number" >&2 + [[ -z "${QUIET}" ]] && echo "Can't connect to port ${port_number}" >&2 fi exit 1