Refact cwhub: simplify enable/disable/download (#2597)

* Extract methods createInstallLink(), removeInstallLink(), simplify
 - the result of filepath.Join is already Cleaned
 - no need to log the creation of parentDir
 - filepath.Abs() only returns error if the current working directory has been removed
* Extract method Item.fetch()
* Replace Create() + Write() -> WriteFile()
This commit is contained in:
mmetc 2023-11-16 13:05:55 +01:00 committed by GitHub
parent d9b0d440bf
commit 65473d4e05
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 139 additions and 143 deletions

View file

@ -48,14 +48,9 @@ func testHub(t *testing.T, update bool) *Hub {
err = os.MkdirAll(local.InstallDataDir, 0o700)
require.NoError(t, err)
index, err := os.Create(local.HubIndexFile)
err = os.WriteFile(local.HubIndexFile, []byte("{}"), 0o644)
require.NoError(t, err)
_, err = index.WriteString(`{}`)
require.NoError(t, err)
index.Close()
t.Cleanup(func() {
os.RemoveAll(tmpDir)
})

View file

@ -10,12 +10,42 @@ import (
log "github.com/sirupsen/logrus"
)
// enable creates a symlink between actual config file at hub.HubDir and hub.ConfigDir
// Handles collections recursively
func (i *Item) enable() error {
parentDir := filepath.Clean(i.hub.local.InstallDir + "/" + i.Type + "/" + i.Stage + "/")
// installLink returns the location of the symlink to the actual config file (eg. /etc/crowdsec/collections/xyz.yaml)
func (i *Item) installLink() string {
return filepath.Join(i.hub.local.InstallDir, i.Type, i.Stage, i.FileName)
}
// create directories if needed
// makeLink creates a symlink between the actual config file at hub.HubDir and hub.ConfigDir
func (i *Item) createInstallLink() error {
dest, err := filepath.Abs(i.installLink())
if err != nil {
return err
}
destDir := filepath.Dir(dest)
if err = os.MkdirAll(destDir, os.ModePerm); err != nil {
return fmt.Errorf("while creating %s: %w", destDir, err)
}
if _, err = os.Lstat(dest); !os.IsNotExist(err) {
log.Infof("%s already exists.", dest)
return nil
}
src, err := filepath.Abs(filepath.Join(i.hub.local.HubDir, i.RemotePath))
if err != nil {
return err
}
if err = os.Symlink(src, dest); err != nil {
return fmt.Errorf("while creating symlink from %s to %s: %w", src, dest, err)
}
return nil
}
// enable enables the item by creating a symlink to the downloaded content, and also enables sub-items
func (i *Item) enable() error {
if i.Installed {
if i.Tainted {
return fmt.Errorf("%s is tainted, won't enable unless --force", i.Name)
@ -32,40 +62,14 @@ func (i *Item) enable() error {
}
}
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
log.Infof("%s doesn't exist, create", parentDir)
if err = os.MkdirAll(parentDir, os.ModePerm); err != nil {
return fmt.Errorf("while creating directory: %w", err)
}
}
// install sub-items if any
for _, sub := range i.SubItems() {
if err := sub.enable(); err != nil {
return fmt.Errorf("while installing %s: %w", sub.Name, err)
}
}
// check if file already exists where it should in configdir (eg /etc/crowdsec/collections/)
if _, err := os.Lstat(parentDir + "/" + i.FileName); !os.IsNotExist(err) {
log.Infof("%s already exists.", parentDir+"/"+i.FileName)
return nil
}
// hub.ConfigDir + target.RemotePath
srcPath, err := filepath.Abs(i.hub.local.HubDir + "/" + i.RemotePath)
if err != nil {
return fmt.Errorf("while getting source path: %w", err)
}
dstPath, err := filepath.Abs(parentDir + "/" + i.FileName)
if err != nil {
return fmt.Errorf("while getting destination path: %w", err)
}
if err = os.Symlink(srcPath, dstPath); err != nil {
return fmt.Errorf("while creating symlink from %s to %s: %w", srcPath, dstPath, err)
if err := i.createInstallLink(); err != nil {
return err
}
log.Infof("Enabled %s: %s", i.Type, i.Name)
@ -81,12 +85,11 @@ func (i *Item) purge() error {
return nil
}
itempath := i.hub.local.HubDir + "/" + i.RemotePath
src := filepath.Join(i.hub.local.HubDir, i.RemotePath)
// disable hub file
if err := os.Remove(itempath); err != nil {
if err := os.Remove(src); err != nil {
if os.IsNotExist(err) {
log.Debugf("%s doesn't exist, no need to remove", itempath)
log.Debugf("%s doesn't exist, no need to remove", src)
return nil
}
@ -94,7 +97,48 @@ func (i *Item) purge() error {
}
i.Downloaded = false
log.Infof("Removed source file [%s]: %s", i.Name, itempath)
log.Infof("Removed source file [%s]: %s", i.Name, src)
return nil
}
func (i *Item) removeInstallLink() error {
syml, err := filepath.Abs(i.installLink())
if err != nil {
return err
}
stat, err := os.Lstat(syml)
if err != nil {
return err
}
// if it's managed by hub, it's a symlink to csconfig.GConfig.hub.HubDir / ...
if stat.Mode()&os.ModeSymlink == 0 {
log.Warningf("%s (%s) isn't a symlink, can't disable", i.Name, syml)
return fmt.Errorf("%s isn't managed by hub", i.Name)
}
hubpath, err := os.Readlink(syml)
if err != nil {
return fmt.Errorf("while reading symlink: %w", err)
}
src, err := filepath.Abs(i.hub.local.HubDir + "/" + i.RemotePath)
if err != nil {
return err
}
if hubpath != src {
log.Warningf("%s (%s) isn't a symlink to %s", i.Name, syml, src)
return fmt.Errorf("%s isn't managed by hub", i.Name)
}
if err := os.Remove(syml); err != nil {
return fmt.Errorf("while removing symlink: %w", err)
}
log.Infof("Removed symlink [%s]: %s", i.Name, syml)
return nil
}
@ -111,6 +155,7 @@ func (i *Item) disable(purge bool, force bool) error {
}
for _, sub := range i.SubItems() {
// TODO XXX: if the other collection(s) are direct or indirect dependencies of the current one, it's good to go
if len(sub.BelongsToCollections) > 1 {
log.Infof("%s was not removed because it belongs to another collection", sub.Name)
continue
@ -126,48 +171,13 @@ func (i *Item) disable(purge bool, force bool) error {
return nil
}
syml, err := filepath.Abs(i.hub.local.InstallDir + "/" + i.Type + "/" + i.Stage + "/" + i.FileName)
if err != nil {
return err
}
stat, err := os.Lstat(syml)
err := i.removeInstallLink()
if os.IsNotExist(err) {
// we only accept to "delete" non existing items if it's a forced purge
if !purge && !force {
return fmt.Errorf("can't delete %s: %s doesn't exist", i.Name, syml)
return fmt.Errorf("can't disable %s: %s doesn't exist", i.Name, i.installLink())
}
} else {
// if it's managed by hub, it's a symlink to csconfig.GConfig.hub.HubDir / ...
if stat.Mode()&os.ModeSymlink == 0 {
log.Warningf("%s (%s) isn't a symlink, can't disable", i.Name, syml)
return fmt.Errorf("%s isn't managed by hub", i.Name)
}
hubpath, err := os.Readlink(syml)
if err != nil {
return fmt.Errorf("while reading symlink: %w", err)
}
absPath, err := filepath.Abs(i.hub.local.HubDir + "/" + i.RemotePath)
if err != nil {
return fmt.Errorf("while abs path: %w", err)
}
if hubpath != absPath {
log.Warningf("%s (%s) isn't a symlink to %s", i.Name, syml, absPath)
return fmt.Errorf("%s isn't managed by hub", i.Name)
}
if err := os.Remove(syml); err != nil {
if os.IsNotExist(err) {
log.Debugf("%s doesn't exist, no need to remove", syml)
return nil
}
return fmt.Errorf("while removing symlink: %w", err)
}
log.Infof("Removed symlink [%s]: %s", i.Name, syml)
} else if err != nil {
return err
}
i.Installed = false

View file

@ -161,14 +161,46 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) error {
return nil
}
func (i *Item) download(overwrite bool) error {
// fetch downloads the item from the hub, verifies the hash and returns the body
func (i *Item) fetch() ([]byte, error) {
url, err := i.hub.remote.urlTo(i.RemotePath)
if err != nil {
return fmt.Errorf("failed to build hub item request: %w", err)
return nil, fmt.Errorf("failed to build hub item request: %w", err)
}
tdir := i.hub.local.HubDir
resp, err := hubClient.Get(url)
if err != nil {
return nil, fmt.Errorf("while downloading %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("bad http code %d for %s", resp.StatusCode, url)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("while downloading %s: %w", url, err)
}
hash := sha256.New()
if _, err = hash.Write(body); err != nil {
return nil, fmt.Errorf("while hashing %s: %w", i.Name, err)
}
meow := hex.EncodeToString(hash.Sum(nil))
if meow != i.Versions[i.Version].Digest {
log.Errorf("Downloaded version doesn't match index, please 'hub update'")
log.Debugf("got %s, expected %s", meow, i.Versions[i.Version].Digest)
return nil, fmt.Errorf("invalid download hash for %s", i.Name)
}
return body, nil
}
// download downloads the item from the hub and writes it to the hub directory
func (i *Item) download(overwrite bool) error {
// if user didn't --force, don't overwrite local, tainted, up-to-date files
if !overwrite {
if i.Tainted {
@ -182,56 +214,29 @@ func (i *Item) download(overwrite bool) error {
}
}
resp, err := hubClient.Get(url)
body, err := i.fetch()
if err != nil {
return fmt.Errorf("while downloading %s: %w", url, err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad http code %d for %s", resp.StatusCode, url)
return err
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("while downloading %s: %w", url, err)
}
hash := sha256.New()
if _, err = hash.Write(body); err != nil {
return fmt.Errorf("while hashing %s: %w", i.Name, err)
}
meow := hex.EncodeToString(hash.Sum(nil))
if meow != i.Versions[i.Version].Digest {
log.Errorf("Downloaded version doesn't match index, please 'hub update'")
log.Debugf("got %s, expected %s", meow, i.Versions[i.Version].Digest)
return fmt.Errorf("invalid download hash for %s", i.Name)
}
tdir := i.hub.local.HubDir
//all good, install
//check if parent dir exists
tmpdirs := strings.Split(tdir+"/"+i.RemotePath, "/")
parentDir := strings.Join(tmpdirs[:len(tmpdirs)-1], "/")
// ensure that target file is within target dir
finalPath, err := filepath.Abs(tdir + "/" + i.RemotePath)
finalPath, err := filepath.Abs(filepath.Join(tdir, i.RemotePath))
if err != nil {
return fmt.Errorf("filepath.Abs error on %s: %w", tdir+"/"+i.RemotePath, err)
return err
}
// ensure that target file is within target dir
if !strings.HasPrefix(finalPath, tdir) {
return fmt.Errorf("path %s escapes %s, abort", i.RemotePath, tdir)
}
// check dir
if _, err = os.Stat(parentDir); os.IsNotExist(err) {
log.Debugf("%s doesn't exist, create", parentDir)
parentDir := filepath.Dir(finalPath)
if err = os.MkdirAll(parentDir, os.ModePerm); err != nil {
return fmt.Errorf("while creating parent directories: %w", err)
}
if err = os.MkdirAll(parentDir, os.ModePerm); err != nil {
return fmt.Errorf("while creating %s: %w", parentDir, err)
}
// check actual file
@ -242,15 +247,8 @@ func (i *Item) download(overwrite bool) error {
log.Infof("%s: OK", i.Name)
}
f, err := os.Create(tdir + "/" + i.RemotePath)
if err != nil {
return fmt.Errorf("while opening file: %w", err)
}
defer f.Close()
_, err = f.Write(body)
if err != nil {
return fmt.Errorf("while writing file: %w", err)
if err = os.WriteFile(finalPath, body, 0o644); err != nil {
return fmt.Errorf("while writing %s: %w", finalPath, err)
}
i.Downloaded = true

View file

@ -69,5 +69,5 @@ func TestDownloadIndex(t *testing.T) {
}
err = hub.remote.downloadIndex("/does/not/exist/index.json")
cstest.RequireErrorContains(t, err, "while opening hub index file: open /does/not/exist/index.json:")
cstest.RequireErrorContains(t, err, "failed to write hub index: open /does/not/exist/index.json:")
}

View file

@ -250,7 +250,7 @@ func (i *Item) versionStatus() int {
}
// validPath returns true if the (relative) path is allowed for the item
// dirNmae: the directory name (ie. crowdsecurity)
// dirNname: the directory name (ie. crowdsecurity)
// fileName: the filename (ie. apache2-logs.yaml)
func (i *Item) validPath(dirName, fileName string) bool {
return (dirName+"/"+fileName == i.Name+".yaml") || (dirName+"/"+fileName == i.Name+".yml")

View file

@ -70,18 +70,11 @@ func (r *RemoteHubCfg) downloadIndex(localPath string) error {
return nil
}
file, err := os.Create(localPath)
if err != nil {
return fmt.Errorf("while opening hub index file: %w", err)
}
defer file.Close()
wsize, err := file.Write(body)
if err != nil {
return fmt.Errorf("while writing hub index file: %w", err)
if err = os.WriteFile(localPath, body, 0o644); err != nil {
return fmt.Errorf("failed to write hub index: %w", err)
}
log.Infof("Wrote index to %s, %d bytes", localPath, wsize)
log.Infof("Wrote index to %s, %d bytes", localPath, len(body))
return nil
}