From 6e228f3f3faedc6668b1bcb1369024fc28f34208 Mon Sep 17 00:00:00 2001 From: Manuel Sabban Date: Mon, 9 Oct 2023 13:26:34 +0200 Subject: [PATCH] pkg/cwhub: cleanup in argument call (#2527) * cleanup in argument call * update test as well * cwhub_tests: reduce verbosity and use helpers --------- Co-authored-by: Marco Mariani --- cmd/crowdsec-cli/config_restore.go | 21 ++- pkg/cwhub/cwhub_test.go | 246 +++++++++-------------------- pkg/cwhub/download.go | 62 ++++---- pkg/cwhub/helpers.go | 26 ++- pkg/cwhub/install.go | 77 ++++----- 5 files changed, 171 insertions(+), 261 deletions(-) diff --git a/cmd/crowdsec-cli/config_restore.go b/cmd/crowdsec-cli/config_restore.go index a9b88f308..395e943bc 100644 --- a/cmd/crowdsec-cli/config_restore.go +++ b/cmd/crowdsec-cli/config_restore.go @@ -27,29 +27,28 @@ func silentInstallItem(name string, obtype string) (string, error) { if item == nil { return "", fmt.Errorf("error retrieving item") } - it := *item - if downloadOnly && it.Downloaded && it.UpToDate { - return fmt.Sprintf("%s is already downloaded and up-to-date", it.Name), nil + if downloadOnly && item.Downloaded && item.UpToDate { + return fmt.Sprintf("%s is already downloaded and up-to-date", item.Name), nil } - it, err := cwhub.DownloadLatest(csConfig.Hub, it, forceAction, false) + err := cwhub.DownloadLatest(csConfig.Hub, item, forceAction, false) if err != nil { - return "", fmt.Errorf("error while downloading %s : %v", it.Name, err) + return "", fmt.Errorf("error while downloading %s : %v", item.Name, err) } - if err := cwhub.AddItem(obtype, it); err != nil { + if err := cwhub.AddItem(obtype, *item); err != nil { return "", err } if downloadOnly { - return fmt.Sprintf("Downloaded %s to %s", it.Name, csConfig.Cscli.HubDir+"/"+it.RemotePath), nil + return fmt.Sprintf("Downloaded %s to %s", item.Name, csConfig.Cscli.HubDir+"/"+item.RemotePath), nil } - it, err = cwhub.EnableItem(csConfig.Hub, it) + err = cwhub.EnableItem(csConfig.Hub, item) if err != nil { - return "", fmt.Errorf("error while enabling %s : %v", it.Name, err) + return "", fmt.Errorf("error while enabling %s : %v", item.Name, err) } - if err := cwhub.AddItem(obtype, it); err != nil { + if err := cwhub.AddItem(obtype, *item); err != nil { return "", err } - return fmt.Sprintf("Enabled %s", it.Name), nil + return fmt.Sprintf("Enabled %s", item.Name), nil } func restoreHub(dirPath string) error { diff --git a/pkg/cwhub/cwhub_test.go b/pkg/cwhub/cwhub_test.go index 5de7b5f4c..1d5efe9d7 100644 --- a/pkg/cwhub/cwhub_test.go +++ b/pkg/cwhub/cwhub_test.go @@ -1,7 +1,6 @@ package cwhub import ( - "fmt" "io" "net/http" "os" @@ -10,6 +9,10 @@ import ( "testing" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/crowdsecurity/go-cs-lib/cstest" "github.com/crowdsecurity/crowdsec/pkg/csconfig" ) @@ -29,28 +32,21 @@ func TestItemStatus(t *testing.T) { cfg := envSetup(t) defer envTearDown(cfg) - err := UpdateHubIdx(cfg.Hub) // DownloadHubIdx() - if err != nil { - t.Fatalf("failed to download index : %s", err) - } + err := UpdateHubIdx(cfg.Hub) + require.NoError(t, err, "failed to download index") - if err := GetHubIdx(cfg.Hub); err != nil { - t.Fatalf("failed to load hub index : %s", err) - } + err = GetHubIdx(cfg.Hub) + require.NoError(t, err, "failed to load hub index") // get existing map x := GetItemMap(COLLECTIONS) - if len(x) == 0 { - t.Fatalf("expected non empty result") - } + require.NotEmpty(t, x) // Get item : good and bad for k := range x { item := GetItem(COLLECTIONS, k) - if item == nil { - t.Fatalf("expected item") - } + require.NotNil(t, item) item.Installed = true item.UpToDate = false @@ -58,9 +54,7 @@ func TestItemStatus(t *testing.T) { item.Tainted = false txt, _ := item.status() - if txt != "enabled,update-available" { - t.Fatalf("got '%s'", txt) - } + require.Equal(t, "enabled,update-available", txt) item.Installed = false item.UpToDate = false @@ -68,11 +62,7 @@ func TestItemStatus(t *testing.T) { item.Tainted = false txt, _ = item.status() - if txt != "disabled,local" { - t.Fatalf("got '%s'", txt) - } - - break + require.Equal(t, "disabled,local", txt) } DisplaySummary() @@ -82,62 +72,39 @@ func TestGetters(t *testing.T) { cfg := envSetup(t) defer envTearDown(cfg) - err := UpdateHubIdx(cfg.Hub) // DownloadHubIdx() - if err != nil { - t.Fatalf("failed to download index : %s", err) - } + err := UpdateHubIdx(cfg.Hub) + require.NoError(t, err, "failed to download index") - if err := GetHubIdx(cfg.Hub); err != nil { - t.Fatalf("failed to load hub index : %s", err) - } + err = GetHubIdx(cfg.Hub) + require.NoError(t, err, "failed to load hub index") // get non existing map empty := GetItemMap("ratata") - if empty != nil { - t.Fatalf("expected nil result") - } + require.Nil(t, empty) // get existing map x := GetItemMap(COLLECTIONS) - if len(x) == 0 { - t.Fatalf("expected non empty result") - } + require.NotEmpty(t, x) // Get item : good and bad for k := range x { empty := GetItem(COLLECTIONS, k+"nope") - if empty != nil { - t.Fatalf("expected empty item") - } + require.Nil(t, empty) item := GetItem(COLLECTIONS, k) - if item == nil { - t.Fatalf("expected non empty item") - } + require.NotNil(t, item) // Add item and get it item.Name += "nope" - if err := AddItem(COLLECTIONS, *item); err != nil { - t.Fatalf("didn't expect error : %s", err) - } + err := AddItem(COLLECTIONS, *item) + require.NoError(t, err) newitem := GetItem(COLLECTIONS, item.Name) - if newitem == nil { - t.Fatalf("expected non empty item") - } + require.NotNil(t, newitem) - // Add bad item - err := AddItem("ratata", *item) - if err == nil { - t.Fatalf("Expected error") - } - - if fmt.Sprintf("%s", err) != "ItemType ratata is unknown" { - t.Fatalf("unexpected error") - } - - break + err = AddItem("ratata", *item) + cstest.RequireErrorContains(t, err, "ItemType ratata is unknown") } } @@ -145,15 +112,12 @@ func TestIndexDownload(t *testing.T) { cfg := envSetup(t) defer envTearDown(cfg) - err := UpdateHubIdx(cfg.Hub) // DownloadHubIdx() - if err != nil { - t.Fatalf("failed to download index : %s", err) - } + err := UpdateHubIdx(cfg.Hub) + require.NoError(t, err, "failed to download index") - if err := GetHubIdx(cfg.Hub); err != nil { - t.Fatalf("failed to load hub index : %s", err) - } + err = GetHubIdx(cfg.Hub) + require.NoError(t, err, "failed to load hub index") } func getTestCfg() *csconfig.Config { @@ -180,17 +144,14 @@ func envSetup(t *testing.T) *csconfig.Config { // Mock the http client http.DefaultClient.Transport = newMockTransport() - if err := os.MkdirAll(cfg.Hub.ConfigDir, 0700); err != nil { - log.Fatalf("mkdir : %s", err) - } + err := os.MkdirAll(cfg.Hub.ConfigDir, 0700) + require.NoError(t, err) - if err := os.MkdirAll(cfg.Hub.HubDir, 0700); err != nil { - log.Fatalf("failed to mkdir %s : %s", cfg.Hub.HubDir, err) - } + err = os.MkdirAll(cfg.Hub.HubDir, 0700) + require.NoError(t, err) - if err := UpdateHubIdx(cfg.Hub); err != nil { - log.Fatalf("failed to download index : %s", err) - } + err = UpdateHubIdx(cfg.Hub) + require.NoError(t, err) // if err := os.RemoveAll(cfg.Hub.InstallDir); err != nil { // log.Fatalf("failed to remove %s : %s", cfg.Hub.InstallDir, err) @@ -213,137 +174,86 @@ func envTearDown(cfg *csconfig.Config) { func testInstallItem(cfg *csconfig.Hub, t *testing.T, item Item) { // Install the parser - item, err := DownloadLatest(cfg, item, false, false) - if err != nil { - t.Fatalf("error while downloading %s : %v", item.Name, err) - } + err := DownloadLatest(cfg, &item, false, false) + require.NoError(t, err, "failed to download %s", item.Name) - if err, _ := LocalSync(cfg); err != nil { - t.Fatalf("taint: failed to run localSync : %s", err) - } + err, _ = LocalSync(cfg) + require.NoError(t, err, "failed to run localSync") - if !hubIdx[item.Type][item.Name].UpToDate { - t.Fatalf("download: %s should be up-to-date", item.Name) - } + assert.True(t, hubIdx[item.Type][item.Name].UpToDate, "%s should be up-to-date", item.Name) + assert.False(t, hubIdx[item.Type][item.Name].Installed, "%s should not be installed", item.Name) + assert.False(t, hubIdx[item.Type][item.Name].Tainted, "%s should not be tainted", item.Name) - if hubIdx[item.Type][item.Name].Installed { - t.Fatalf("download: %s should not be installed", item.Name) - } + err = EnableItem(cfg, &item) + require.NoError(t, err, "failed to enable %s", item.Name) - if hubIdx[item.Type][item.Name].Tainted { - t.Fatalf("download: %s should not be tainted", item.Name) - } + err, _ = LocalSync(cfg) + require.NoError(t, err, "failed to run localSync") - item, err = EnableItem(cfg, item) - if err != nil { - t.Fatalf("error while enabling %s : %v.", item.Name, err) - } - - if err, _ := LocalSync(cfg); err != nil { - t.Fatalf("taint: failed to run localSync : %s", err) - } - - if !hubIdx[item.Type][item.Name].Installed { - t.Fatalf("install: %s should be installed", item.Name) - } + assert.True(t, hubIdx[item.Type][item.Name].Installed, "%s should be installed", item.Name) } func testTaintItem(cfg *csconfig.Hub, t *testing.T, item Item) { - if hubIdx[item.Type][item.Name].Tainted { - t.Fatalf("pre-taint: %s should not be tainted", item.Name) - } + assert.False(t, hubIdx[item.Type][item.Name].Tainted, "%s should not be tainted", item.Name) f, err := os.OpenFile(item.LocalPath, os.O_APPEND|os.O_WRONLY, 0600) - if err != nil { - t.Fatalf("(taint) opening %s (%s) : %s", item.LocalPath, item.Name, err) - } + require.NoError(t, err, "failed to open %s (%s)", item.LocalPath, item.Name) + defer f.Close() - if _, err = f.WriteString("tainted"); err != nil { - t.Fatalf("tainting %s : %s", item.Name, err) - } + _, err = f.WriteString("tainted") + require.NoError(t, err, "failed to write to %s (%s)", item.LocalPath, item.Name) // Local sync and check status - if err, _ := LocalSync(cfg); err != nil { - t.Fatalf("taint: failed to run localSync : %s", err) - } + err, _ = LocalSync(cfg) + require.NoError(t, err, "failed to run localSync") - if !hubIdx[item.Type][item.Name].Tainted { - t.Fatalf("taint: %s should be tainted", item.Name) - } + assert.True(t, hubIdx[item.Type][item.Name].Tainted, "%s should be tainted", item.Name) } func testUpdateItem(cfg *csconfig.Hub, t *testing.T, item Item) { - if hubIdx[item.Type][item.Name].UpToDate { - t.Fatalf("update: %s should NOT be up-to-date", item.Name) - } + assert.False(t, hubIdx[item.Type][item.Name].UpToDate, "%s should not be up-to-date", item.Name) // Update it + check status - item, err := DownloadLatest(cfg, item, true, true) - if err != nil { - t.Fatalf("failed to update %s : %s", item.Name, err) - } + err := DownloadLatest(cfg, &item, true, true) + require.NoError(t, err, "failed to update %s", item.Name) // Local sync and check status - if err, _ := LocalSync(cfg); err != nil { - t.Fatalf("failed to run localSync : %s", err) - } + err, _ = LocalSync(cfg) + require.NoError(t, err, "failed to run localSync") - if !hubIdx[item.Type][item.Name].UpToDate { - t.Fatalf("update: %s should be up-to-date", item.Name) - } - - if hubIdx[item.Type][item.Name].Tainted { - t.Fatalf("update: %s should not be tainted anymore", item.Name) - } + assert.True(t, hubIdx[item.Type][item.Name].UpToDate, "%s should be up-to-date", item.Name) + assert.False(t, hubIdx[item.Type][item.Name].Tainted, "%s should not be tainted anymore", item.Name) } func testDisableItem(cfg *csconfig.Hub, t *testing.T, item Item) { - if !item.Installed { - t.Fatalf("disable: %s should be installed", item.Name) - } + assert.True(t, hubIdx[item.Type][item.Name].Installed, "%s should be installed", item.Name) // Remove - item, err := DisableItem(cfg, item, false, false) - if err != nil { - t.Fatalf("failed to disable item : %v", err) - } + err := DisableItem(cfg, &item, false, false) + require.NoError(t, err, "failed to disable %s", item.Name) // Local sync and check status - if err, warns := LocalSync(cfg); err != nil || len(warns) > 0 { - t.Fatalf("failed to run localSync : %s (%+v)", err, warns) - } + err, warns := LocalSync(cfg) + require.NoError(t, err, "failed to run localSync") + require.Empty(t, warns, "unexpected warnings : %+v", warns) - if hubIdx[item.Type][item.Name].Tainted { - t.Fatalf("disable: %s should not be tainted anymore", item.Name) - } - - if hubIdx[item.Type][item.Name].Installed { - t.Fatalf("disable: %s should not be installed anymore", item.Name) - } - - if !hubIdx[item.Type][item.Name].Downloaded { - t.Fatalf("disable: %s should still be downloaded", item.Name) - } + assert.False(t, hubIdx[item.Type][item.Name].Tainted, "%s should not be tainted anymore", item.Name) + assert.False(t, hubIdx[item.Type][item.Name].Installed, "%s should not be installed anymore", item.Name) + assert.True(t, hubIdx[item.Type][item.Name].Downloaded, "%s should still be downloaded", item.Name) // Purge - item, err = DisableItem(cfg, item, true, false) - if err != nil { - t.Fatalf("failed to purge item : %v", err) - } + err = DisableItem(cfg, &item, true, false) + require.NoError(t, err, "failed to purge %s", item.Name) // Local sync and check status - if err, warns := LocalSync(cfg); err != nil || len(warns) > 0 { - t.Fatalf("failed to run localSync : %s (%+v)", err, warns) - } + err, warns = LocalSync(cfg) + require.NoError(t, err, "failed to run localSync") + require.Empty(t, warns, "unexpected warnings : %+v", warns) - if hubIdx[item.Type][item.Name].Installed { - t.Fatalf("disable: %s should not be installed anymore", item.Name) - } - if hubIdx[item.Type][item.Name].Downloaded { - t.Fatalf("disable: %s should not be downloaded", item.Name) - } + assert.False(t, hubIdx[item.Type][item.Name].Installed, "%s should not be installed anymore", item.Name) + assert.False(t, hubIdx[item.Type][item.Name].Downloaded, "%s should not be downloaded", item.Name) } func TestInstallParser(t *testing.T) { diff --git a/pkg/cwhub/download.go b/pkg/cwhub/download.go index 1f809a052..c5c69b444 100644 --- a/pkg/cwhub/download.go +++ b/pkg/cwhub/download.go @@ -96,7 +96,7 @@ func DownloadHubIdx(hub *csconfig.Hub) ([]byte, error) { } // DownloadLatest will download the latest version of Item to the tdir directory -func DownloadLatest(hub *csconfig.Hub, target Item, overwrite bool, updateOnly bool) (Item, error) { +func DownloadLatest(hub *csconfig.Hub, target *Item, overwrite bool, updateOnly bool) error { var err error log.Debugf("Downloading %s %s", target.Type, target.Name) @@ -104,7 +104,7 @@ func DownloadLatest(hub *csconfig.Hub, target Item, overwrite bool, updateOnly b if target.Type != COLLECTIONS { if !target.Installed && updateOnly && target.Downloaded { log.Debugf("skipping upgrade of %s : not installed", target.Name) - return target, nil + return nil } return DownloadItem(hub, target, overwrite) @@ -117,7 +117,7 @@ func DownloadLatest(hub *csconfig.Hub, target Item, overwrite bool, updateOnly b for _, p := range ptr { val, ok := hubIdx[ptrtype][p] if !ok { - return target, fmt.Errorf("required %s %s of %s doesn't exist, abort", ptrtype, p, target.Name) + return fmt.Errorf("required %s %s of %s doesn't exist, abort", ptrtype, p, target.Name) } if !val.Installed && updateOnly && val.Downloaded { @@ -130,38 +130,38 @@ func DownloadLatest(hub *csconfig.Hub, target Item, overwrite bool, updateOnly b if ptrtype == COLLECTIONS { log.Tracef("collection, recurse") - hubIdx[ptrtype][p], err = DownloadLatest(hub, val, overwrite, updateOnly) + err = DownloadLatest(hub, &val, overwrite, updateOnly) if err != nil { - return target, fmt.Errorf("while downloading %s: %w", val.Name, err) + return fmt.Errorf("while downloading %s: %w", val.Name, err) } } - - item, err := DownloadItem(hub, val, overwrite) + downloaded := val.Downloaded + err := DownloadItem(hub, &val, overwrite) if err != nil { - return target, fmt.Errorf("while downloading %s: %w", val.Name, err) + return fmt.Errorf("while downloading %s: %w", val.Name, err) } // We need to enable an item when it has been added to a collection since latest release of the collection. // We check if val.Downloaded is false because maybe the item has been disabled by the user. - if !item.Installed && !val.Downloaded { - if item, err = EnableItem(hub, item); err != nil { - return target, fmt.Errorf("enabling '%s': %w", item.Name, err) + if !val.Installed && !downloaded { + if err = EnableItem(hub, &val); err != nil { + return fmt.Errorf("enabling '%s': %w", val.Name, err) } } - hubIdx[ptrtype][p] = item + hubIdx[ptrtype][p] = val } } - target, err = DownloadItem(hub, target, overwrite) + err = DownloadItem(hub, target, overwrite) if err != nil { - return target, fmt.Errorf("failed to download item: %w", err) + return fmt.Errorf("failed to download item: %w", err) } - return target, nil + return nil } -func DownloadItem(hub *csconfig.Hub, target Item, overwrite bool) (Item, error) { +func DownloadItem(hub *csconfig.Hub, target *Item, overwrite bool) error { tdir := hub.HubDir dataFolder := hub.DataDir @@ -169,7 +169,7 @@ func DownloadItem(hub *csconfig.Hub, target Item, overwrite bool) (Item, error) if !overwrite { if target.Tainted { log.Debugf("%s : tainted, not updated", target.Name) - return target, nil + return nil } if target.UpToDate { @@ -180,28 +180,28 @@ func DownloadItem(hub *csconfig.Hub, target Item, overwrite bool) (Item, error) req, err := http.NewRequest(http.MethodGet, fmt.Sprintf(RawFileURLTemplate, HubBranch, target.RemotePath), nil) if err != nil { - return target, fmt.Errorf("while downloading %s: %w", req.URL.String(), err) + return fmt.Errorf("while downloading %s: %w", req.URL.String(), err) } resp, err := http.DefaultClient.Do(req) if err != nil { - return target, fmt.Errorf("while downloading %s: %w", req.URL.String(), err) + return fmt.Errorf("while downloading %s: %w", req.URL.String(), err) } if resp.StatusCode != http.StatusOK { - return target, fmt.Errorf("bad http code %d for %s", resp.StatusCode, req.URL.String()) + return fmt.Errorf("bad http code %d for %s", resp.StatusCode, req.URL.String()) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return target, fmt.Errorf("while reading %s: %w", req.URL.String(), err) + return fmt.Errorf("while reading %s: %w", req.URL.String(), err) } h := sha256.New() if _, err := h.Write(body); err != nil { - return target, fmt.Errorf("while hashing %s: %w", target.Name, err) + return fmt.Errorf("while hashing %s: %w", target.Name, err) } meow := fmt.Sprintf("%x", h.Sum(nil)) @@ -209,7 +209,7 @@ func DownloadItem(hub *csconfig.Hub, target Item, overwrite bool) (Item, error) log.Errorf("Downloaded version doesn't match index, please 'hub update'") log.Debugf("got %s, expected %s", meow, target.Versions[target.Version].Digest) - return target, fmt.Errorf("invalid download hash for %s", target.Name) + return fmt.Errorf("invalid download hash for %s", target.Name) } //all good, install @@ -220,11 +220,11 @@ func DownloadItem(hub *csconfig.Hub, target Item, overwrite bool) (Item, error) // ensure that target file is within target dir finalPath, err := filepath.Abs(tdir + "/" + target.RemotePath) if err != nil { - return target, fmt.Errorf("filepath.Abs error on %s: %w", tdir+"/"+target.RemotePath, err) + return fmt.Errorf("filepath.Abs error on %s: %w", tdir+"/"+target.RemotePath, err) } if !strings.HasPrefix(finalPath, tdir) { - return target, fmt.Errorf("path %s escapes %s, abort", target.RemotePath, tdir) + return fmt.Errorf("path %s escapes %s, abort", target.RemotePath, tdir) } // check dir @@ -232,7 +232,7 @@ func DownloadItem(hub *csconfig.Hub, target Item, overwrite bool) (Item, error) log.Debugf("%s doesn't exist, create", parent_dir) if err := os.MkdirAll(parent_dir, os.ModePerm); err != nil { - return target, fmt.Errorf("while creating parent directories: %w", err) + return fmt.Errorf("while creating parent directories: %w", err) } } @@ -246,14 +246,14 @@ func DownloadItem(hub *csconfig.Hub, target Item, overwrite bool) (Item, error) f, err := os.OpenFile(tdir+"/"+target.RemotePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { - return target, fmt.Errorf("while opening file: %w", err) + return fmt.Errorf("while opening file: %w", err) } defer f.Close() _, err = f.Write(body) if err != nil { - return target, fmt.Errorf("while writing file: %w", err) + return fmt.Errorf("while writing file: %w", err) } target.Downloaded = true @@ -261,12 +261,12 @@ func DownloadItem(hub *csconfig.Hub, target Item, overwrite bool) (Item, error) target.UpToDate = true if err = downloadData(dataFolder, overwrite, bytes.NewReader(body)); err != nil { - return target, fmt.Errorf("while downloading data for %s: %w", target.FileName, err) + return fmt.Errorf("while downloading data for %s: %w", target.FileName, err) } - hubIdx[target.Type][target.Name] = target + hubIdx[target.Type][target.Name] = *target - return target, nil + return nil } func DownloadDataIfNeeded(hub *csconfig.Hub, target Item, force bool) error { diff --git a/pkg/cwhub/helpers.go b/pkg/cwhub/helpers.go index c1a9042d3..e9ef5e200 100644 --- a/pkg/cwhub/helpers.go +++ b/pkg/cwhub/helpers.go @@ -62,12 +62,11 @@ func SetHubBranch() { } func InstallItem(csConfig *csconfig.Config, name string, obtype string, force bool, downloadOnly bool) error { - it := GetItem(obtype, name) - if it == nil { + item := GetItem(obtype, name) + if item == nil { return fmt.Errorf("unable to retrieve item: %s", name) } - item := *it if downloadOnly && item.Downloaded && item.UpToDate { log.Warningf("%s is already downloaded and up-to-date", item.Name) @@ -76,12 +75,12 @@ func InstallItem(csConfig *csconfig.Config, name string, obtype string, force bo } } - item, err := DownloadLatest(csConfig.Hub, item, force, true) + err := DownloadLatest(csConfig.Hub, item, force, true) if err != nil { return fmt.Errorf("while downloading %s: %w", item.Name, err) } - if err := AddItem(obtype, item); err != nil { + if err := AddItem(obtype, *item); err != nil { return fmt.Errorf("while adding %s: %w", item.Name, err) } @@ -90,12 +89,12 @@ func InstallItem(csConfig *csconfig.Config, name string, obtype string, force bo return nil } - item, err = EnableItem(csConfig.Hub, item) + err = EnableItem(csConfig.Hub, item) if err != nil { return fmt.Errorf("while enabling %s: %w", item.Name, err) } - if err := AddItem(obtype, item); err != nil { + if err := AddItem(obtype, *item); err != nil { return fmt.Errorf("while adding %s: %w", item.Name, err) } @@ -112,19 +111,18 @@ func RemoveMany(csConfig *csconfig.Config, itemType string, name string, all boo ) if name != "" { - it := GetItem(itemType, name) - if it == nil { + item := GetItem(itemType, name) + if item == nil { log.Fatalf("unable to retrieve: %s", name) } - item := *it - item, err = DisableItem(csConfig.Hub, item, purge, forceAction) + err = DisableItem(csConfig.Hub, item, purge, forceAction) if err != nil { log.Fatalf("unable to disable %s : %v", item.Name, err) } - if err := AddItem(itemType, item); err != nil { + if err := AddItem(itemType, *item); err != nil { log.Fatalf("unable to add %s: %v", item.Name, err) } @@ -141,7 +139,7 @@ func RemoveMany(csConfig *csconfig.Config, itemType string, name string, all boo continue } - v, err = DisableItem(csConfig.Hub, v, purge, forceAction) + err = DisableItem(csConfig.Hub, &v, purge, forceAction) if err != nil { log.Fatalf("unable to disable %s : %v", v.Name, err) } @@ -191,7 +189,7 @@ func UpgradeConfig(csConfig *csconfig.Config, itemType string, name string, forc } } - v, err = DownloadLatest(csConfig.Hub, v, force, true) + err = DownloadLatest(csConfig.Hub, &v, force, true) if err != nil { log.Fatalf("%s : download failed : %v", v.Name, err) } diff --git a/pkg/cwhub/install.go b/pkg/cwhub/install.go index 0c18e8a86..b029bc108 100644 --- a/pkg/cwhub/install.go +++ b/pkg/cwhub/install.go @@ -26,31 +26,28 @@ func purgeItem(hub *csconfig.Hub, target Item) (Item, error) { } // DisableItem to disable an item managed by the hub, removes the symlink if purge is true -func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item, error) { +func DisableItem(hub *csconfig.Hub, target *Item, purge bool, force bool) error { var err error + // already disabled, noop unless purge if !target.Installed { if purge { - target, err = purgeItem(hub, target) + *target, err = purgeItem(hub, *target) if err != nil { - return target, err + return err } } - return target, nil + return nil } - syml, err := filepath.Abs(hub.ConfigDir + "/" + target.Type + "/" + target.Stage + "/" + target.FileName) - if err != nil { - return Item{}, err - } if target.Local { - return target, fmt.Errorf("%s isn't managed by hub. Please delete manually", target.Name) + return fmt.Errorf("%s isn't managed by hub. Please delete manually", target.Name) } if target.Tainted && !force { - return target, fmt.Errorf("%s is tainted, use '--force' to overwrite", target.Name) + return fmt.Errorf("%s is tainted, use '--force' to overwrite", target.Name) } // for a COLLECTIONS, disable sub-items @@ -71,9 +68,9 @@ func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item, } if toRemove { - hubIdx[ptrtype][p], err = DisableItem(hub, val, purge, force) + err = DisableItem(hub, &val, purge, force) if err != nil { - return target, fmt.Errorf("while disabling %s: %w", p, err) + return fmt.Errorf("while disabling %s: %w", p, err) } } else { log.Infof("%s was not removed because it belongs to another collection", val.Name) @@ -85,36 +82,42 @@ func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item, } } + syml, err := filepath.Abs(hub.ConfigDir + "/" + target.Type + "/" + target.Stage + "/" + target.FileName) + if err != nil { + return err + } + stat, err := os.Lstat(syml) if os.IsNotExist(err) { - if !purge && !force { // we only accept to "delete" non existing items if it's a purge - return target, fmt.Errorf("can't delete %s : %s doesn't exist", target.Name, syml) + // we only accept to "delete" non existing items if it's a forced purge + if !purge && !force { + return fmt.Errorf("can't delete %s : %s doesn't exist", target.Name, syml) } } else { // if it's managed by hub, it's a symlink to csconfig.GConfig.hub.HubDir / ... if stat.Mode()&os.ModeSymlink == 0 { log.Warningf("%s (%s) isn't a symlink, can't disable", target.Name, syml) - return target, fmt.Errorf("%s isn't managed by hub", target.Name) + return fmt.Errorf("%s isn't managed by hub", target.Name) } hubpath, err := os.Readlink(syml) if err != nil { - return target, fmt.Errorf("while reading symlink: %w", err) + return fmt.Errorf("while reading symlink: %w", err) } absPath, err := filepath.Abs(hub.HubDir + "/" + target.RemotePath) if err != nil { - return target, fmt.Errorf("while abs path: %w", err) + return fmt.Errorf("while abs path: %w", err) } if hubpath != absPath { log.Warningf("%s (%s) isn't a symlink to %s", target.Name, syml, absPath) - return target, fmt.Errorf("%s isn't managed by hub", target.Name) + return fmt.Errorf("%s isn't managed by hub", target.Name) } // remove the symlink if err = os.Remove(syml); err != nil { - return target, fmt.Errorf("while removing symlink: %w", err) + return fmt.Errorf("while removing symlink: %w", err) } log.Infof("Removed symlink [%s] : %s", target.Name, syml) @@ -123,20 +126,20 @@ func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item, target.Installed = false if purge { - target, err = purgeItem(hub, target) + *target, err = purgeItem(hub, *target) if err != nil { - return target, err + return err } } - hubIdx[target.Type][target.Name] = target + hubIdx[target.Type][target.Name] = *target - return target, nil + return nil } // creates symlink between actual config file at hub.HubDir and hub.ConfigDir // Handles collections recursively -func EnableItem(hub *csconfig.Hub, target Item) (Item, error) { +func EnableItem(hub *csconfig.Hub, target *Item) error { var err error parent_dir := filepath.Clean(hub.ConfigDir + "/" + target.Type + "/" + target.Stage + "/") @@ -144,17 +147,17 @@ func EnableItem(hub *csconfig.Hub, target Item) (Item, error) { // create directories if needed if target.Installed { if target.Tainted { - return target, fmt.Errorf("%s is tainted, won't enable unless --force", target.Name) + return fmt.Errorf("%s is tainted, won't enable unless --force", target.Name) } if target.Local { - return target, fmt.Errorf("%s is local, won't enable", target.Name) + return fmt.Errorf("%s is local, won't enable", target.Name) } // if it's a collection, check sub-items even if the collection file itself is up-to-date if target.UpToDate && target.Type != COLLECTIONS { log.Tracef("%s is installed and up-to-date, skip.", target.Name) - return target, nil + return nil } } @@ -162,7 +165,7 @@ func EnableItem(hub *csconfig.Hub, target Item) (Item, error) { log.Infof("%s doesn't exist, create", parent_dir) if err := os.MkdirAll(parent_dir, os.ModePerm); err != nil { - return target, fmt.Errorf("while creating directory: %w", err) + return fmt.Errorf("while creating directory: %w", err) } } @@ -174,12 +177,12 @@ func EnableItem(hub *csconfig.Hub, target Item) (Item, error) { for _, p := range ptr { val, ok := hubIdx[ptrtype][p] if !ok { - return target, fmt.Errorf("required %s %s of %s doesn't exist, abort", ptrtype, p, target.Name) + return fmt.Errorf("required %s %s of %s doesn't exist, abort", ptrtype, p, target.Name) } - hubIdx[ptrtype][p], err = EnableItem(hub, val) + err = EnableItem(hub, &val) if err != nil { - return target, fmt.Errorf("while installing %s: %w", p, err) + return fmt.Errorf("while installing %s: %w", p, err) } } } @@ -188,27 +191,27 @@ func EnableItem(hub *csconfig.Hub, target Item) (Item, error) { // check if file already exists where it should in configdir (eg /etc/crowdsec/collections/) if _, err := os.Lstat(parent_dir + "/" + target.FileName); !os.IsNotExist(err) { log.Infof("%s already exists.", parent_dir+"/"+target.FileName) - return target, nil + return nil } // hub.ConfigDir + target.RemotePath srcPath, err := filepath.Abs(hub.HubDir + "/" + target.RemotePath) if err != nil { - return target, fmt.Errorf("while getting source path: %w", err) + return fmt.Errorf("while getting source path: %w", err) } dstPath, err := filepath.Abs(parent_dir + "/" + target.FileName) if err != nil { - return target, fmt.Errorf("while getting destination path: %w", err) + return fmt.Errorf("while getting destination path: %w", err) } if err = os.Symlink(srcPath, dstPath); err != nil { - return target, fmt.Errorf("while creating symlink from %s to %s: %w", srcPath, dstPath, err) + return fmt.Errorf("while creating symlink from %s to %s: %w", srcPath, dstPath, err) } log.Infof("Enabled %s : %s", target.Type, target.Name) target.Installed = true - hubIdx[target.Type][target.Name] = target + hubIdx[target.Type][target.Name] = *target - return target, nil + return nil }