Compare commits

...

3 commits

Author SHA1 Message Date
Marco Mariani eaf8ad57c1 Merge branch 'master' into unix-socket 2023-12-11 15:23:52 +01:00
Boris Rybalkin 24154aefa1 functional test, test version fix 2023-07-13 10:19:34 +01:00
Boris Rybalkin c9d383dbad local api unix socket support 2023-07-13 09:39:54 +01:00
10 changed files with 345 additions and 56 deletions

2
.gitignore vendored
View file

@ -57,3 +57,5 @@ msi
__pycache__
*.py[cod]
*.egg-info
.idea

View file

@ -97,21 +97,7 @@ func runLapiRegister(cmd *cobra.Command, args []string) error {
}
}
password := strfmt.Password(generatePassword(passwordLength))
if apiURL == "" {
if csConfig.API.Client == nil || csConfig.API.Client.Credentials == nil || csConfig.API.Client.Credentials.URL == "" {
return fmt.Errorf("no Local API URL. Please provide it in your configuration or with the -u parameter")
}
apiURL = csConfig.API.Client.Credentials.URL
}
/*URL needs to end with /, but user doesn't care*/
if !strings.HasSuffix(apiURL, "/") {
apiURL += "/"
}
/*URL needs to start with http://, but user doesn't care*/
if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") {
apiURL = "http://" + apiURL
}
apiurl, err := url.Parse(apiURL)
apiurl, err := prepareApiURL(csConfig.API.Client, apiURL)
if err != nil {
return fmt.Errorf("parsing api url: %w", err)
}
@ -160,6 +146,24 @@ func runLapiRegister(cmd *cobra.Command, args []string) error {
return nil
}
func prepareApiURL(clientConfig *csconfig.LocalApiClientCfg, apiURL string) (*url.URL, error) {
if apiURL == "" {
if clientConfig == nil || clientConfig.Credentials == nil || clientConfig.Credentials.URL == "" {
return nil, fmt.Errorf("no Local API URL. Please provide it in your configuration or with the -u parameter")
}
apiURL = clientConfig.Credentials.URL
}
if !strings.HasSuffix(apiURL, "/") {
apiURL += "/"
}
if !strings.HasPrefix(apiURL, "http://") && !strings.HasPrefix(apiURL, "https://") && !strings.HasPrefix(apiURL, "/") {
apiURL = "http://" + apiURL
}
return url.Parse(apiURL)
}
func NewLapiStatusCmd() *cobra.Command {
cmdLapiStatus := &cobra.Command{
Use: "status",

View file

@ -0,0 +1,58 @@
package main
import (
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/stretchr/testify/assert"
"testing"
)
func TestPrepareApiURl_NoProtocol(t *testing.T) {
url, err := prepareApiURl(nil, "localhost:81")
assert.NoError(t, err)
assert.Equal(t, "http://localhost:81/", url.String())
}
func TestPrepareApiURl_Http(t *testing.T) {
url, err := prepareApiURl(nil, "http://localhost:81")
assert.NoError(t, err)
assert.Equal(t, "http://localhost:81/", url.String())
}
func TestPrepareApiURl_Https(t *testing.T) {
url, err := prepareApiURl(nil, "https://localhost:81")
assert.NoError(t, err)
assert.Equal(t, "https://localhost:81/", url.String())
}
func TestPrepareApiURl_UnixSocket(t *testing.T) {
url, err := prepareApiURl(nil, "/path/socket")
assert.NoError(t, err)
assert.Equal(t, "/path/socket/", url.String())
}
func TestPrepareApiURl_Empty(t *testing.T) {
_, err := prepareApiURl(nil, "")
assert.Error(t, err)
}
func TestPrepareApiURl_Empty_ConfigOverride(t *testing.T) {
url, err := prepareApiURl(&csconfig.LocalApiClientCfg{
Credentials: &csconfig.ApiCredentialsCfg{
URL: "localhost:80",
},
}, "")
assert.NoError(t, err)
assert.Equal(t, "http://localhost:80/", url.String())
}

View file

@ -317,7 +317,7 @@ func (cli cliMachines) add(cmd *cobra.Command, args []string) error {
if csConfig.API.Client != nil && csConfig.API.Client.Credentials != nil && csConfig.API.Client.Credentials.URL != "" {
apiURL = csConfig.API.Client.Credentials.URL
} else if csConfig.API.Server != nil && csConfig.API.Server.ListenURI != "" {
apiURL = "http://" + csConfig.API.Server.ListenURI
apiURL = csConfig.API.Server.ClientUrl()
} else {
return fmt.Errorf("unable to dump an api URL. Please provide it in your configuration or with the -u parameter")
}

View file

@ -183,9 +183,14 @@ func (t *JWTTransport) refreshJwtToken() error {
return fmt.Errorf("could not create request: %w", err)
}
req.Header.Add("Content-Type", "application/json")
transport := t.Transport
if transport == nil {
transport = http.DefaultTransport
}
client := &http.Client{
Transport: &retryRoundTripper{
next: http.DefaultTransport,
next: transport,
maxAttempts: 5,
withBackOff: true,
retryStatusCodes: []int{http.StatusTooManyRequests, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusInternalServerError},

View file

@ -6,11 +6,12 @@ import (
"crypto/x509"
"encoding/json"
"fmt"
"github.com/crowdsecurity/crowdsec/pkg/models"
"io"
"net"
"net/http"
"net/url"
"github.com/crowdsecurity/crowdsec/pkg/models"
"strings"
)
var (
@ -52,11 +53,16 @@ func NewClient(config *Config) (*ApiClient, error) {
MachineID: &config.MachineID,
Password: &config.Password,
Scenarios: config.Scenarios,
URL: config.URL,
UserAgent: config.UserAgent,
VersionPrefix: config.VersionPrefix,
UpdateScenario: config.UpdateScenario,
}
transport, baseUrl := CreateTransport(config.URL)
if transport != nil {
t.Transport = transport
}
t.URL = baseUrl
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
tlsconfig.RootCAs = CaCertPool
if Cert != nil {
@ -65,7 +71,7 @@ func NewClient(config *Config) (*ApiClient, error) {
if ht, ok := http.DefaultTransport.(*http.Transport); ok {
ht.TLSClientConfig = &tlsconfig
}
c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL}
c := &ApiClient{client: t.Client(), BaseURL: baseUrl, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix, PapiURL: config.PapiURL}
c.common.client = c
c.Decisions = (*DecisionsService)(&c.common)
c.Alerts = (*AlertsService)(&c.common)
@ -79,19 +85,24 @@ func NewClient(config *Config) (*ApiClient, error) {
}
func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *http.Client) (*ApiClient, error) {
transport, baseUrl := CreateTransport(URL)
if client == nil {
client = &http.Client{}
if ht, ok := http.DefaultTransport.(*http.Transport); ok {
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
tlsconfig.RootCAs = CaCertPool
if Cert != nil {
tlsconfig.Certificates = []tls.Certificate{*Cert}
if transport != nil {
client.Transport = transport
} else {
if ht, ok := http.DefaultTransport.(*http.Transport); ok {
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
tlsconfig.RootCAs = CaCertPool
if Cert != nil {
tlsconfig.Certificates = []tls.Certificate{*Cert}
}
ht.TLSClientConfig = &tlsconfig
client.Transport = ht
}
ht.TLSClientConfig = &tlsconfig
client.Transport = ht
}
}
c := &ApiClient{client: client, BaseURL: URL, UserAgent: userAgent, URLPrefix: prefix}
c := &ApiClient{client: client, BaseURL: baseUrl, UserAgent: userAgent, URLPrefix: prefix}
c.common.client = c
c.Decisions = (*DecisionsService)(&c.common)
c.Alerts = (*AlertsService)(&c.common)
@ -105,16 +116,24 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt
}
func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
transport, baseUrl := CreateTransport(config.URL)
if client == nil {
client = &http.Client{}
if transport != nil {
client.Transport = transport
} else {
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
if Cert != nil {
tlsconfig.RootCAs = CaCertPool
tlsconfig.Certificates = []tls.Certificate{*Cert}
}
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig
}
} else if client.Transport == nil && transport != nil {
client.Transport = transport
}
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
if Cert != nil {
tlsconfig.RootCAs = CaCertPool
tlsconfig.Certificates = []tls.Certificate{*Cert}
}
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig
c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix}
c := &ApiClient{client: client, BaseURL: baseUrl, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix}
c.common.client = c
c.Decisions = (*DecisionsService)(&c.common)
c.Alerts = (*AlertsService)(&c.common)
@ -132,6 +151,26 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
}
func CreateTransport(url *url.URL) (*http.Transport, *url.URL) {
urlString := url.String()
if strings.HasPrefix(urlString, "/") {
ToUnixSocketUrl(url)
return &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return net.Dial("unix", strings.TrimSuffix(urlString, "/"))
},
}, url
} else {
return nil, url
}
}
func ToUnixSocketUrl(url *url.URL) {
url.Path = "/"
url.Host = "unix"
url.Scheme = "http"
}
type Response struct {
Response *http.Response
//add our pagination stuff

View file

@ -3,9 +3,11 @@ package apiclient
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"path"
"runtime"
"testing"
@ -32,12 +34,26 @@ func setupWithPrefix(urlPrefix string) (mux *http.ServeMux, serverURL string, te
apiHandler := http.NewServeMux()
apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux))
// server is a test HTTP server used to provide mock API responses.
server := httptest.NewServer(apiHandler)
return mux, server.URL, server.Close
}
func setupUnixSocketWithPrefix(socket string, urlPrefix string) (mux *http.ServeMux, serverURL string, teardown func()) {
mux = http.NewServeMux()
baseURLPath := "/" + urlPrefix
apiHandler := http.NewServeMux()
apiHandler.Handle(baseURLPath+"/", http.StripPrefix(baseURLPath, mux))
server := httptest.NewUnstartedServer(apiHandler)
l, _ := net.Listen("unix", socket)
_ = server.Listener.Close()
server.Listener = l
server.Start()
return mux, socket, server.Close
}
func testMethod(t *testing.T, r *http.Request, want string) {
t.Helper()
if got := r.Method; got != want {
@ -82,6 +98,49 @@ func TestNewClientOk(t *testing.T) {
}
}
func TestNewClientOk_UnixSocket(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping on windows")
}
tmpDir := t.TempDir()
socket := path.Join(tmpDir, "socket")
mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1")
defer teardown()
apiURL, err := url.Parse(urlx)
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
client, err := NewClient(&Config{
MachineID: "test_login",
Password: "test_password",
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
VersionPrefix: "v1",
})
if err != nil {
t.Fatalf("new api client: %s", err)
}
/*mock login*/
mux.HandleFunc("/watchers/login", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "GET")
w.WriteHeader(http.StatusOK)
})
_, resp, err := client.Alerts.List(context.Background(), AlertsListOpts{})
if err != nil {
t.Fatalf("test Unable to list alerts : %+v", err)
}
if resp.Response.StatusCode != http.StatusOK {
t.Fatalf("Alerts.List returned status: %d, want %d", resp.Response.StatusCode, http.StatusCreated)
}
}
func TestNewClientKo(t *testing.T) {
mux, urlx, teardown := setup()
defer teardown()
@ -135,6 +194,31 @@ func TestNewDefaultClient(t *testing.T) {
log.Printf("err-> %s", err)
}
func TestNewDefaultClient_UnixSocket(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping on windows")
}
tmpDir := t.TempDir()
socket := path.Join(tmpDir, "socket")
mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1")
defer teardown()
apiURL, err := url.Parse(urlx)
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
client, err := NewDefaultClient(apiURL, "/v1", "", nil)
if err != nil {
t.Fatalf("new api client: %s", err)
}
mux.HandleFunc("/alerts", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"code": 401, "message" : "brr"}`))
})
_, _, err = client.Alerts.List(context.Background(), AlertsListOpts{})
assert.Contains(t, err.Error(), `performing request: API error: brr`)
log.Printf("err-> %s", err)
}
func TestNewClientRegisterKO(t *testing.T) {
apiURL, err := url.Parse("http://127.0.0.1:4242/")
if err != nil {
@ -183,6 +267,40 @@ func TestNewClientRegisterOK(t *testing.T) {
log.Printf("->%T", client)
}
func TestNewClientRegisterOK_UnixSocket(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping on windows")
}
log.SetLevel(log.TraceLevel)
tmpDir := t.TempDir()
socket := path.Join(tmpDir, "socket")
mux, urlx, teardown := setupUnixSocketWithPrefix(socket, "v1")
defer teardown()
/*mock login*/
mux.HandleFunc("/watchers", func(w http.ResponseWriter, r *http.Request) {
testMethod(t, r, "POST")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"code": 200, "expire": "2030-01-02T15:04:05Z", "token": "oklol"}`))
})
apiURL, err := url.Parse(urlx)
if err != nil {
t.Fatalf("parsing api url: %s", apiURL)
}
client, err := RegisterClient(&Config{
MachineID: "test_login",
Password: "test_password",
UserAgent: fmt.Sprintf("crowdsec/%s", version.String()),
URL: apiURL,
VersionPrefix: "v1",
}, &http.Client{})
if err != nil {
t.Fatalf("while registering client : %s", err)
}
log.Printf("->%T", client)
}
func TestNewClientBadAnswer(t *testing.T) {
log.SetLevel(log.TraceLevel)
mux, urlx, teardown := setup()

View file

@ -38,6 +38,7 @@ var (
type APIServer struct {
URL string
isUnixSocket bool
TLS *csconfig.TLSCfg
dbClient *database.Client
logFile string
@ -243,6 +244,7 @@ func NewServer(config *csconfig.LocalApiServerCfg) (*APIServer, error) {
return &APIServer{
URL: config.ListenURI,
isUnixSocket: config.IsUnixSocket(),
TLS: config.TLS,
logFile: logFile,
dbClient: dbClient,
@ -390,19 +392,30 @@ func (s *APIServer) Run(apiReady chan bool) error {
go func() {
apiReady <- true
log.Infof("CrowdSec Local API listening on %s", s.URL)
if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") {
if s.TLS.KeyFilePath == "" {
log.Fatalf("while serving local API: %v", errors.New("missing TLS key file"))
} else if s.TLS.CertFilePath == "" {
log.Fatalf("while serving local API: %v", errors.New("missing TLS cert file"))
if s.isUnixSocket {
_ = os.RemoveAll(s.URL)
listener, err := net.Listen("unix", s.URL)
if err != nil {
log.Fatalf("while creating unix listener: %v", err)
}
if err := s.httpServer.ListenAndServeTLS(s.TLS.CertFilePath, s.TLS.KeyFilePath); err != nil {
log.Fatalf("while serving local API: %v", err)
if err = s.httpServer.Serve(listener); err != http.ErrServerClosed {
log.Fatalf("while serving local API (unix socket): %v", err)
}
} else {
if err := s.httpServer.ListenAndServe(); err != http.ErrServerClosed {
log.Fatalf("while serving local API: %v", err)
if s.TLS != nil && (s.TLS.CertFilePath != "" || s.TLS.KeyFilePath != "") {
if s.TLS.KeyFilePath == "" {
log.Fatalf("while serving local API: %v", errors.New("missing TLS key file"))
} else if s.TLS.CertFilePath == "" {
log.Fatalf("while serving local API: %v", errors.New("missing TLS cert file"))
}
if err := s.httpServer.ListenAndServeTLS(s.TLS.CertFilePath, s.TLS.KeyFilePath); err != nil {
log.Fatalf("while serving local API (tcp tls): %v", err)
}
} else {
if err = s.httpServer.ListenAndServe(); err != http.ErrServerClosed {
log.Fatalf("while serving local API (tcp): %v", err)
}
}
}
}()

View file

@ -156,9 +156,9 @@ func (l *LocalApiClientCfg) Load() error {
return nil
}
func (lapiCfg *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) {
func (c *LocalApiServerCfg) GetTrustedIPs() ([]net.IPNet, error) {
trustedIPs := make([]net.IPNet, 0)
for _, ip := range lapiCfg.TrustedIPs {
for _, ip := range c.TrustedIPs {
cidr := toValidCIDR(ip)
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
@ -212,6 +212,17 @@ type LocalApiServerCfg struct {
CapiWhitelists *CapiWhitelist `yaml:"-"`
}
func (c *LocalApiServerCfg) IsUnixSocket() bool {
return strings.HasPrefix(c.ListenURI, "/")
}
func (c *LocalApiServerCfg) ClientUrl() string {
if c.IsUnixSocket() {
return c.ListenURI
}
return fmt.Sprintf("http://%s", c.ListenURI)
}
type TLSCfg struct {
CertFilePath string `yaml:"cert_file"`
KeyFilePath string `yaml:"key_file"`
@ -360,25 +371,25 @@ func parseCapiWhitelists(fd io.Reader) (*CapiWhitelist, error) {
return ret, nil
}
func (s *LocalApiServerCfg) LoadCapiWhitelists() error {
if s.CapiWhitelistsPath == "" {
func (c *LocalApiServerCfg) LoadCapiWhitelists() error {
if c.CapiWhitelistsPath == "" {
return nil
}
if _, err := os.Stat(s.CapiWhitelistsPath); os.IsNotExist(err) {
return fmt.Errorf("capi whitelist file '%s' does not exist", s.CapiWhitelistsPath)
if _, err := os.Stat(c.CapiWhitelistsPath); os.IsNotExist(err) {
return fmt.Errorf("capi whitelist file '%s' does not exist", c.CapiWhitelistsPath)
}
fd, err := os.Open(s.CapiWhitelistsPath)
fd, err := os.Open(c.CapiWhitelistsPath)
if err != nil {
return fmt.Errorf("while opening capi whitelist file: %s", err)
}
defer fd.Close()
s.CapiWhitelists, err = parseCapiWhitelists(fd)
c.CapiWhitelists, err = parseCapiWhitelists(fd)
if err != nil {
return fmt.Errorf("while parsing capi whitelist file '%s': %w", s.CapiWhitelistsPath, err)
return fmt.Errorf("while parsing capi whitelist file '%s': %w", c.CapiWhitelistsPath, err)
}
return nil

39
test/bats/09_socket.bats Normal file
View file

@ -0,0 +1,39 @@
#!/usr/bin/env bats
# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si:
set -u
setup_file() {
load "../lib/setup_file.sh"
}
teardown_file() {
load "../lib/teardown_file.sh"
}
setup() {
load "../lib/setup.sh"
load "../lib/bats-file/load.bash"
./instance-data load
}
teardown() {
./instance-crowdsec stop
}
#----------
@test "cscli - connects with socket" {
sockdir=$(TMPDIR="${BATS_TEST_TMPDIR}" mktemp -u)
mkdir -p "${sockdir}"
export socket="${sockdir}/crowdsec_api.sock"
config_set ".api.server.listen_uri=strenv(socket)"
LOCAL_API_CREDENTIALS=$(config_get '.api.client.credentials_path')
config_set "${LOCAL_API_CREDENTIALS}" ".url=strenv(socket)"
./instance-crowdsec start
rune -0 cscli lapi status
assert_stderr --partial "You can successfully interact with Local API (LAPI)"
}