pkg/cwhub: download data assets to temporary files to avoid partial fetch (#2879)

This commit is contained in:
mmetc 2024-03-08 10:55:30 +01:00 committed by GitHub
parent 1eab943ec2
commit 6c5e8afde9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 33 additions and 17 deletions

View file

@ -6,6 +6,7 @@ import (
"io" "io"
"net/http" "net/http"
"os" "os"
"path/filepath"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -31,19 +32,32 @@ func downloadFile(url string, destPath string) error {
return fmt.Errorf("bad http code %d for %s", resp.StatusCode, url) return fmt.Errorf("bad http code %d for %s", resp.StatusCode, url)
} }
file, err := os.Create(destPath) tmpFile, err := os.CreateTemp(filepath.Dir(destPath), filepath.Base(destPath)+".*.tmp")
if err != nil { if err != nil {
return err return err
} }
defer file.Close()
tmpFileName := tmpFile.Name()
defer func() {
tmpFile.Close()
os.Remove(tmpFileName)
}()
// avoid reading the whole file in memory // avoid reading the whole file in memory
_, err = io.Copy(file, resp.Body) _, err = io.Copy(tmpFile, resp.Body)
if err != nil { if err != nil {
return err return err
} }
if err = file.Sync(); err != nil { if err = tmpFile.Sync(); err != nil {
return err
}
if err = tmpFile.Close(); err != nil {
return err
}
if err = os.Rename(tmpFileName, destPath); err != nil {
return err return err
} }

View file

@ -16,7 +16,7 @@ func TestDownloadFile(t *testing.T) {
httpmock.Activate() httpmock.Activate()
defer httpmock.DeactivateAndReset() defer httpmock.DeactivateAndReset()
//OK // OK
httpmock.RegisterResponder( httpmock.RegisterResponder(
"GET", "GET",
"https://example.com/xx", "https://example.com/xx",
@ -36,15 +36,15 @@ func TestDownloadFile(t *testing.T) {
assert.Equal(t, "example content oneoneone", string(content)) assert.Equal(t, "example content oneoneone", string(content))
require.NoError(t, err) require.NoError(t, err)
//bad uri // bad uri
err = downloadFile("https://zz.com", examplePath) err = downloadFile("https://zz.com", examplePath)
require.Error(t, err) require.Error(t, err)
//404 // 404
err = downloadFile("https://example.com/x", examplePath) err = downloadFile("https://example.com/x", examplePath)
require.Error(t, err) require.Error(t, err)
//bad target // bad target
err = downloadFile("https://example.com/xx", "") err = downloadFile("https://example.com/xx", "")
require.Error(t, err) require.Error(t, err)
} }

View file

@ -6,7 +6,7 @@ import (
) )
var ( var (
// ErrNilRemoteHub is returned when the remote hub configuration is not provided to the NewHub constructor. // ErrNilRemoteHub is returned when trying to download with a local-only configuration.
ErrNilRemoteHub = errors.New("remote hub configuration is not provided. Please report this issue to the developers") ErrNilRemoteHub = errors.New("remote hub configuration is not provided. Please report this issue to the developers")
) )

View file

@ -3,6 +3,7 @@ package cwhub
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -34,7 +35,7 @@ func (h *Hub) GetDataDir() string {
// All download operations (including updateIndex) return ErrNilRemoteHub if the remote configuration is not set. // All download operations (including updateIndex) return ErrNilRemoteHub if the remote configuration is not set.
func NewHub(local *csconfig.LocalHubCfg, remote *RemoteHubCfg, updateIndex bool, logger *logrus.Logger) (*Hub, error) { func NewHub(local *csconfig.LocalHubCfg, remote *RemoteHubCfg, updateIndex bool, logger *logrus.Logger) (*Hub, error) {
if local == nil { if local == nil {
return nil, fmt.Errorf("no hub configuration found") return nil, errors.New("no hub configuration found")
} }
if logger == nil { if logger == nil {

View file

@ -77,9 +77,9 @@ func (h *Hub) getItemFileInfo(path string, logger *logrus.Logger) (*itemFileInfo
if strings.HasPrefix(path, hubDir) { if strings.HasPrefix(path, hubDir) {
logger.Tracef("in hub dir") logger.Tracef("in hub dir")
//.../hub/parsers/s00-raw/crowdsec/skip-pretag.yaml // .../hub/parsers/s00-raw/crowdsec/skip-pretag.yaml
//.../hub/scenarios/crowdsec/ssh_bf.yaml // .../hub/scenarios/crowdsec/ssh_bf.yaml
//.../hub/profiles/crowdsec/linux.yaml // .../hub/profiles/crowdsec/linux.yaml
if len(subs) < 4 { if len(subs) < 4 {
return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subs)) return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subs))
} }
@ -93,13 +93,14 @@ func (h *Hub) getItemFileInfo(path string, logger *logrus.Logger) (*itemFileInfo
} }
} else if strings.HasPrefix(path, installDir) { // we're in install /etc/crowdsec/<type>/... } else if strings.HasPrefix(path, installDir) { // we're in install /etc/crowdsec/<type>/...
logger.Tracef("in install dir") logger.Tracef("in install dir")
if len(subs) < 3 { if len(subs) < 3 {
return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subs)) return nil, fmt.Errorf("path is too short: %s (%d)", path, len(subs))
} }
///.../config/parser/stage/file.yaml // .../config/parser/stage/file.yaml
///.../config/postoverflow/stage/file.yaml // .../config/postoverflow/stage/file.yaml
///.../config/scenarios/scenar.yaml // .../config/scenarios/scenar.yaml
///.../config/collections/linux.yaml //file is empty // .../config/collections/linux.yaml //file is empty
ret = &itemFileInfo{ ret = &itemFileInfo{
inhub: false, inhub: false,
fname: subs[len(subs)-1], fname: subs[len(subs)-1],