Refact pkg/cwhub (part 1) (#2512)

* wrap errors, whitespace
* remove named return
* reverse CheckSuffix logic, rename function
* drop redundant if/else, happy path
* log.Fatal -> fmt.Errorf
* simplify GetItemMap, AddItem
* var -> const
* removed short-lived vars
* de-duplicate function and reverse logic
This commit is contained in:
mmetc 2023-10-04 10:34:10 +02:00 committed by GitHub
parent 8b5ad6990d
commit d39131d154
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 88 additions and 122 deletions

View file

@ -16,17 +16,17 @@ import (
) )
// managed configuration types // managed configuration types
var PARSERS = "parsers" const PARSERS = "parsers"
var PARSERS_OVFLW = "postoverflows" const PARSERS_OVFLW = "postoverflows"
var SCENARIOS = "scenarios" const SCENARIOS = "scenarios"
var COLLECTIONS = "collections" const COLLECTIONS = "collections"
var ItemTypes = []string{PARSERS, PARSERS_OVFLW, SCENARIOS, COLLECTIONS} var ItemTypes = []string{PARSERS, PARSERS_OVFLW, SCENARIOS, COLLECTIONS}
var hubIdx map[string]map[string]Item var hubIdx map[string]map[string]Item
var RawFileURLTemplate = "https://hub-cdn.crowdsec.net/%s/%s" var RawFileURLTemplate = "https://hub-cdn.crowdsec.net/%s/%s"
var HubBranch = "master" var HubBranch = "master"
var HubIndexFile = ".index.json" const HubIndexFile = ".index.json"
type ItemVersion struct { type ItemVersion struct {
Digest string `json:"digest,omitempty"` Digest string `json:"digest,omitempty"`
@ -119,26 +119,22 @@ func getSHA256(filepath string) (string, error) {
// Digest of file // Digest of file
f, err := os.Open(filepath) f, err := os.Open(filepath)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to open '%s': %s", filepath, err) return "", fmt.Errorf("unable to open '%s': %w", filepath, err)
} }
defer f.Close() defer f.Close()
h := sha256.New() h := sha256.New()
if _, err := io.Copy(h, f); err != nil { if _, err := io.Copy(h, f); err != nil {
return "", fmt.Errorf("unable to calculate sha256 of '%s': %s", filepath, err) return "", fmt.Errorf("unable to calculate sha256 of '%s': %w", filepath, err)
} }
return fmt.Sprintf("%x", h.Sum(nil)), nil return fmt.Sprintf("%x", h.Sum(nil)), nil
} }
func GetItemMap(itemType string) map[string]Item { func GetItemMap(itemType string) map[string]Item {
var ( m, ok := hubIdx[itemType]
m map[string]Item if !ok {
ok bool
)
if m, ok = hubIdx[itemType]; !ok {
return nil return nil
} }
@ -194,22 +190,15 @@ func GetItem(itemType string, itemName string) *Item {
} }
func AddItem(itemType string, item Item) error { func AddItem(itemType string, item Item) error {
in := false
for _, itype := range ItemTypes { for _, itype := range ItemTypes {
if itype == itemType { if itype == itemType {
in = true
}
}
if !in {
return fmt.Errorf("ItemType %s is unknown", itemType)
}
hubIdx[itemType][item.Name] = item hubIdx[itemType][item.Name] = item
return nil return nil
} }
}
return fmt.Errorf("ItemType %s is unknown", itemType)
}
func DisplaySummary() { func DisplaySummary() {
log.Printf("Loaded %d collecs, %d parsers, %d scenarios, %d post-overflow parsers", len(hubIdx[COLLECTIONS]), log.Printf("Loaded %d collecs, %d parsers, %d scenarios, %d post-overflow parsers", len(hubIdx[COLLECTIONS]),
@ -242,8 +231,8 @@ func ItemStatus(v Item) (string, bool, bool, bool) {
Warning = true Warning = true
strret += ",tainted" strret += ",tainted"
} else if !v.UpToDate && !v.Local { } else if !v.UpToDate && !v.Local {
strret += ",update-available"
Warning = true Warning = true
strret += ",update-available"
} }
return strret, Ok, Warning, Managed return strret, Ok, Warning, Managed

View file

@ -128,13 +128,14 @@ func TestGetters(t *testing.T) {
} }
// Add bad item // Add bad item
if err := AddItem("ratata", *item); err != nil { err := AddItem("ratata", *item)
if err == nil {
t.Fatalf("Expected error")
}
if fmt.Sprintf("%s", err) != "ItemType ratata is unknown" { if fmt.Sprintf("%s", err) != "ItemType ratata is unknown" {
t.Fatalf("unexpected error") t.Fatalf("unexpected error")
} }
} else {
t.Fatalf("Expected error")
}
break break
} }
@ -155,13 +156,13 @@ func TestIndexDownload(t *testing.T) {
} }
} }
func getTestCfg() (cfg *csconfig.Config) { func getTestCfg() *csconfig.Config {
cfg = &csconfig.Config{Hub: &csconfig.Hub{}} cfg := &csconfig.Config{Hub: &csconfig.Hub{}}
cfg.Hub.ConfigDir, _ = filepath.Abs("./install") cfg.Hub.ConfigDir, _ = filepath.Abs("./install")
cfg.Hub.HubDir, _ = filepath.Abs("./hubdir") cfg.Hub.HubDir, _ = filepath.Abs("./hubdir")
cfg.Hub.HubIndexFile = filepath.Clean("./hubdir/.index.json") cfg.Hub.HubIndexFile = filepath.Clean("./hubdir/.index.json")
return return cfg
} }
func envSetup(t *testing.T) *csconfig.Config { func envSetup(t *testing.T) *csconfig.Config {
@ -424,34 +425,32 @@ func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
} }
response.Header.Set("Content-Type", "application/json") response.Header.Set("Content-Type", "application/json")
responseBody := ""
log.Printf("---> %s", req.URL.Path) log.Printf("---> %s", req.URL.Path)
// FAKE PARSER // FAKE PARSER
if resp, ok := responseByPath[req.URL.Path]; ok { resp, ok := responseByPath[req.URL.Path]
responseBody = resp if !ok {
} else {
log.Fatalf("unexpected url :/ %s", req.URL.Path) log.Fatalf("unexpected url :/ %s", req.URL.Path)
} }
response.Body = io.NopCloser(strings.NewReader(responseBody)) response.Body = io.NopCloser(strings.NewReader(resp))
return response, nil return response, nil
} }
func fileToStringX(path string) string { func fileToStringX(path string) string {
if f, err := os.Open(path); err == nil { f, err := os.Open(path)
if err != nil {
panic(err)
}
defer f.Close() defer f.Close()
if data, err := io.ReadAll(f); err == nil { data, err := io.ReadAll(f)
if err != nil {
panic(err)
}
return strings.ReplaceAll(string(data), "\r\n", "\n") return strings.ReplaceAll(string(data), "\r\n", "\n")
} else {
panic(err)
}
} else {
panic(err)
}
} }
func resetResponseByPath() { func resetResponseByPath() {

View file

@ -155,7 +155,7 @@ func DownloadLatest(hub *csconfig.Hub, target Item, overwrite bool, updateOnly b
target, err = DownloadItem(hub, target, overwrite) target, err = DownloadItem(hub, target, overwrite)
if err != nil { if err != nil {
return target, fmt.Errorf("failed to download item : %s", err) return target, fmt.Errorf("failed to download item: %w", err)
} }
return target, nil return target, nil

View file

@ -11,16 +11,15 @@ import (
) )
func purgeItem(hub *csconfig.Hub, target Item) (Item, error) { func purgeItem(hub *csconfig.Hub, target Item) (Item, error) {
var hdir = hub.HubDir itempath := hub.HubDir + "/" + target.RemotePath
hubpath := hdir + "/" + target.RemotePath
// disable hub file // disable hub file
if err := os.Remove(hubpath); err != nil { if err := os.Remove(itempath); err != nil {
return target, fmt.Errorf("while removing file: %w", err) return target, fmt.Errorf("while removing file: %w", err)
} }
target.Downloaded = false target.Downloaded = false
log.Infof("Removed source file [%s] : %s", target.Name, hubpath) log.Infof("Removed source file [%s]: %s", target.Name, itempath)
hubIdx[target.Type][target.Name] = target hubIdx[target.Type][target.Name] = target
return target, nil return target, nil
@ -30,9 +29,6 @@ func purgeItem(hub *csconfig.Hub, target Item) (Item, error) {
func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item, error) { func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item, error) {
var err error var err error
tdir := hub.ConfigDir
hdir := hub.HubDir
if !target.Installed { if !target.Installed {
if purge { if purge {
target, err = purgeItem(hub, target) target, err = purgeItem(hub, target)
@ -44,7 +40,7 @@ func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item,
return target, nil return target, nil
} }
syml, err := filepath.Abs(tdir + "/" + target.Type + "/" + target.Stage + "/" + target.FileName) syml, err := filepath.Abs(hub.ConfigDir + "/" + target.Type + "/" + target.Stage + "/" + target.FileName)
if err != nil { if err != nil {
return Item{}, err return Item{}, err
} }
@ -100,14 +96,17 @@ func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item,
log.Warningf("%s (%s) isn't a symlink, can't disable", target.Name, syml) log.Warningf("%s (%s) isn't a symlink, can't disable", target.Name, syml)
return target, fmt.Errorf("%s isn't managed by hub", target.Name) return target, fmt.Errorf("%s isn't managed by hub", target.Name)
} }
hubpath, err := os.Readlink(syml) hubpath, err := os.Readlink(syml)
if err != nil { if err != nil {
return target, fmt.Errorf("while reading symlink: %w", err) return target, fmt.Errorf("while reading symlink: %w", err)
} }
absPath, err := filepath.Abs(hdir + "/" + target.RemotePath)
absPath, err := filepath.Abs(hub.HubDir + "/" + target.RemotePath)
if err != nil { if err != nil {
return target, fmt.Errorf("while abs path: %w", err) return target, fmt.Errorf("while abs path: %w", err)
} }
if hubpath != absPath { if hubpath != absPath {
log.Warningf("%s (%s) isn't a symlink to %s", target.Name, syml, absPath) log.Warningf("%s (%s) isn't a symlink to %s", target.Name, syml, absPath)
return target, fmt.Errorf("%s isn't managed by hub", target.Name) return target, fmt.Errorf("%s isn't managed by hub", target.Name)
@ -117,6 +116,7 @@ func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item,
if err = os.Remove(syml); err != nil { if err = os.Remove(syml); err != nil {
return target, fmt.Errorf("while removing symlink: %w", err) return target, fmt.Errorf("while removing symlink: %w", err)
} }
log.Infof("Removed symlink [%s] : %s", target.Name, syml) log.Infof("Removed symlink [%s] : %s", target.Name, syml)
} }
@ -139,10 +139,8 @@ func DisableItem(hub *csconfig.Hub, target Item, purge bool, force bool) (Item,
func EnableItem(hub *csconfig.Hub, target Item) (Item, error) { func EnableItem(hub *csconfig.Hub, target Item) (Item, error) {
var err error var err error
tdir := hub.ConfigDir parent_dir := filepath.Clean(hub.ConfigDir + "/" + target.Type + "/" + target.Stage + "/")
hdir := hub.HubDir
parent_dir := filepath.Clean(tdir + "/" + target.Type + "/" + target.Stage + "/")
// create directories if needed // create directories if needed
if target.Installed { if target.Installed {
if target.Tainted { if target.Tainted {
@ -152,6 +150,7 @@ func EnableItem(hub *csconfig.Hub, target Item) (Item, error) {
if target.Local { if target.Local {
return target, fmt.Errorf("%s is local, won't enable", target.Name) return target, fmt.Errorf("%s is local, won't enable", target.Name)
} }
// if it's a collection, check sub-items even if the collection file itself is up-to-date // if it's a collection, check sub-items even if the collection file itself is up-to-date
if target.UpToDate && target.Type != COLLECTIONS { if target.UpToDate && target.Type != COLLECTIONS {
log.Tracef("%s is installed and up-to-date, skip.", target.Name) log.Tracef("%s is installed and up-to-date, skip.", target.Name)
@ -192,8 +191,8 @@ func EnableItem(hub *csconfig.Hub, target Item) (Item, error) {
return target, nil return target, nil
} }
// tdir+target.RemotePath // hub.ConfigDir + target.RemotePath
srcPath, err := filepath.Abs(hdir + "/" + target.RemotePath) srcPath, err := filepath.Abs(hub.HubDir + "/" + target.RemotePath)
if err != nil { if err != nil {
return target, fmt.Errorf("while getting source path: %w", err) return target, fmt.Errorf("while getting source path: %w", err)
} }

View file

@ -15,10 +15,14 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
) )
// the walk/parser_visit function can't receive extra args // the walk/parserVisit function can't receive extra args
var hubdir, installdir string var hubdir, installdir string
func parser_visit(path string, f os.DirEntry, err error) error { func validItemFileName(vname string, fauthor string, fname string) bool {
return (fauthor+"/"+fname == vname+".yaml") || (fauthor+"/"+fname == vname+".yml")
}
func parserVisit(path string, f os.DirEntry, err error) error {
var ( var (
target Item target Item
local bool local bool
@ -122,7 +126,7 @@ func parser_visit(path string, f os.DirEntry, err error) error {
log.Infof("%s is a symlink to %s that doesn't exist, deleting symlink", path, hubpath) log.Infof("%s is a symlink to %s that doesn't exist, deleting symlink", path, hubpath)
// remove the symlink // remove the symlink
if err = os.Remove(path); err != nil { if err = os.Remove(path); err != nil {
return fmt.Errorf("failed to unlink %s: %+v", path, err) return fmt.Errorf("failed to unlink %s: %w", path, err)
} }
return nil return nil
} }
@ -171,8 +175,9 @@ func parser_visit(path string, f os.DirEntry, err error) error {
if fauthor != v.Author { if fauthor != v.Author {
continue continue
} }
// wrong file // wrong file
if CheckName(v.Name, fauthor, fname) { if !validItemFileName(v.Name, fauthor, fname) {
continue continue
} }
@ -180,7 +185,7 @@ func parser_visit(path string, f os.DirEntry, err error) error {
log.Tracef("marking %s as downloaded", v.Name) log.Tracef("marking %s as downloaded", v.Name)
v.Downloaded = true v.Downloaded = true
} }
} else if CheckSuffix(hubpath, v.RemotePath) { } else if !hasPathSuffix(hubpath, v.RemotePath) {
// wrong file // wrong file
// <type>/<stage>/<author>/<name>.yaml // <type>/<stage>/<author>/<name>.yaml
continue continue
@ -234,7 +239,7 @@ func parser_visit(path string, f os.DirEntry, err error) error {
if !match { if !match {
log.Tracef("got tainted match for %s : %s", v.Name, path) log.Tracef("got tainted match for %s : %s", v.Name, path)
skippedTainted += 1 skippedTainted++
// the file and the stage is right, but the hash is wrong, it has been tainted by user // the file and the stage is right, but the hash is wrong, it has been tainted by user
if !inhub { if !inhub {
v.LocalPath = path v.LocalPath = path
@ -297,7 +302,7 @@ func CollecDepsCheck(v *Item) error {
v.Tainted = true v.Tainted = true
} }
return fmt.Errorf("sub collection %s is broken: %s", val.Name, err) return fmt.Errorf("sub collection %s is broken: %w", val.Name, err)
} }
hubIdx[ptrtype][p] = val hubIdx[ptrtype][p] = val
@ -353,7 +358,7 @@ func SyncDir(hub *csconfig.Hub, dir string) (error, []string) {
log.Errorf("failed %s : %s", cpath, err) log.Errorf("failed %s : %s", cpath, err)
} }
err = filepath.WalkDir(cpath, parser_visit) err = filepath.WalkDir(cpath, parserVisit)
if err != nil { if err != nil {
return err, warnings return err, warnings
} }
@ -387,12 +392,12 @@ func LocalSync(hub *csconfig.Hub) (error, []string) {
err, warnings := SyncDir(hub, hub.ConfigDir) err, warnings := SyncDir(hub, hub.ConfigDir)
if err != nil { if err != nil {
return fmt.Errorf("failed to scan %s : %s", hub.ConfigDir, err), warnings return fmt.Errorf("failed to scan %s: %w", hub.ConfigDir, err), warnings
} }
err, _ = SyncDir(hub, hub.HubDir) err, _ = SyncDir(hub, hub.HubDir)
if err != nil { if err != nil {
return fmt.Errorf("failed to scan %s : %s", hub.HubDir, err), warnings return fmt.Errorf("failed to scan %s: %w", hub.HubDir, err), warnings
} }
return nil, warnings return nil, warnings
@ -413,7 +418,7 @@ func GetHubIdx(hub *csconfig.Hub) error {
ret, err := LoadPkgIndex(bidx) ret, err := LoadPkgIndex(bidx)
if err != nil { if err != nil {
if !errors.Is(err, ReferenceMissingError) { if !errors.Is(err, ReferenceMissingError) {
log.Fatalf("Unable to load existing index : %v.", err) return fmt.Errorf("unable to load existing index: %w", err)
} }
return err return err
@ -423,7 +428,7 @@ func GetHubIdx(hub *csconfig.Hub) error {
err, _ = LocalSync(hub) err, _ = LocalSync(hub)
if err != nil { if err != nil {
log.Fatalf("Failed to sync Hub index with local deployment : %v", err) return fmt.Errorf("failed to sync Hub index with local deployment : %w", err)
} }
return nil return nil
@ -438,7 +443,7 @@ func LoadPkgIndex(buff []byte) (map[string]map[string]Item, error) {
) )
if err = json.Unmarshal(buff, &RawIndex); err != nil { if err = json.Unmarshal(buff, &RawIndex); err != nil {
return nil, fmt.Errorf("failed to unmarshal index : %v", err) return nil, fmt.Errorf("failed to unmarshal index: %w", err)
} }
log.Debugf("%d item types in hub index", len(ItemTypes)) log.Debugf("%d item types in hub index", len(ItemTypes))

View file

@ -1,23 +0,0 @@
package cwhub
import (
"path/filepath"
"strings"
)
func CheckSuffix(hubpath string, remotePath string) bool {
newPath := filepath.ToSlash(hubpath)
if !strings.HasSuffix(newPath, remotePath) {
return true
} else {
return false
}
}
func CheckName(vname string, fauthor string, fname string) bool {
if vname+".yaml" != fauthor+"/"+fname && vname+".yml" != fauthor+"/"+fname {
return true
} else {
return false
}
}

View file

@ -5,20 +5,6 @@ package cwhub
import "strings" import "strings"
const PathSeparator = "/" func hasPathSuffix(hubpath string, remotePath string) bool {
return strings.HasSuffix(hubpath, remotePath)
func CheckSuffix(hubpath string, remotePath string) bool {
if !strings.HasSuffix(hubpath, remotePath) {
return true
} else {
return false
}
}
func CheckName(vname string, fauthor string, fname string) bool {
if vname+".yaml" != fauthor+"/"+fname && vname+".yml" != fauthor+"/"+fname {
return true
} else {
return false
}
} }

View file

@ -0,0 +1,11 @@
package cwhub
import (
"path/filepath"
"strings"
)
func hasPathSuffix(hubpath string, remotePath string) bool {
newPath := filepath.ToSlash(hubpath)
return strings.HasSuffix(newPath, remotePath)
}