diff --git a/.github/workflows/bats-hub.yml b/.github/workflows/bats-hub.yml index 82cb40ada..e5dab9136 100644 --- a/.github/workflows/bats-hub.yml +++ b/.github/workflows/bats-hub.yml @@ -35,8 +35,12 @@ jobs: - name: "Install bats dependencies" run: | sudo apt install -y -qq build-essential daemonize jq netcat-openbsd - GO111MODULE=on go get github.com/mikefarah/yq/v4 + go install github.com/mikefarah/yq/v4@latest + go install github.com/cloudflare/cfssl/cmd/cfssl@latest + go install github.com/cloudflare/cfssl/cmd/cfssljson@latest sudo cp -u ~/go/bin/yq /usr/local/bin/ + sudo cp -u ~/go/bin/cfssl /usr/local/bin + sudo cp -u ~/go/bin/cfssljson /usr/local/bin - name: "Build crowdsec and fixture" run: make bats-clean bats-build bats-fixture diff --git a/.github/workflows/bats-mysql.yml b/.github/workflows/bats-mysql.yml index 02e428561..0ae0faba2 100644 --- a/.github/workflows/bats-mysql.yml +++ b/.github/workflows/bats-mysql.yml @@ -46,8 +46,12 @@ jobs: - name: "Install bats dependencies" run: | sudo apt install -y -qq build-essential daemonize jq netcat-openbsd - GO111MODULE=on go get github.com/mikefarah/yq/v4 + go install github.com/mikefarah/yq/v4@latest + go install github.com/cloudflare/cfssl/cmd/cfssl@latest + go install github.com/cloudflare/cfssl/cmd/cfssljson@latest sudo cp -u ~/go/bin/yq /usr/local/bin/ + sudo cp -u ~/go/bin/cfssl /usr/local/bin + sudo cp -u ~/go/bin/cfssljson /usr/local/bin - name: "Build crowdsec and fixture" run: make bats-clean bats-build bats-fixture diff --git a/.github/workflows/bats-postgres.yml b/.github/workflows/bats-postgres.yml index a5ad7d64b..8e6e380b1 100644 --- a/.github/workflows/bats-postgres.yml +++ b/.github/workflows/bats-postgres.yml @@ -47,8 +47,12 @@ jobs: - name: "Install bats dependencies" run: | sudo apt install -y -qq build-essential daemonize jq netcat-openbsd - GO111MODULE=on go get github.com/mikefarah/yq/v4 + go install github.com/mikefarah/yq/v4@latest + go install github.com/cloudflare/cfssl/cmd/cfssl@latest + go install github.com/cloudflare/cfssl/cmd/cfssljson@latest sudo cp -u ~/go/bin/yq /usr/local/bin/ + sudo cp -u ~/go/bin/cfssl /usr/local/bin + sudo cp -u ~/go/bin/cfssljson /usr/local/bin - name: "Build crowdsec and fixture (DB_BACKEND: pgx)" run: make clean bats-build bats-fixture diff --git a/.github/workflows/bats-sqlite-coverage.yml b/.github/workflows/bats-sqlite-coverage.yml index b49069b78..4b3c3213f 100644 --- a/.github/workflows/bats-sqlite-coverage.yml +++ b/.github/workflows/bats-sqlite-coverage.yml @@ -32,10 +32,11 @@ jobs: - name: "Install bats dependencies" run: | sudo apt install -y -qq build-essential daemonize jq netcat-openbsd - GO111MODULE=on go get github.com/mikefarah/yq/v4 - sudo cp -u ~/go/bin/yq /usr/local/bin/ + go install github.com/mikefarah/yq/v4@latest + go install github.com/cloudflare/cfssl/cmd/cfssl@latest + go install github.com/cloudflare/cfssl/cmd/cfssljson@latest go install github.com/wadey/gocovmerge@latest - sudo cp -u ~/go/bin/gocovmerge /usr/local/bin/ + sudo cp -u ~/go/bin/yq ~/go/bin/gocovmerge ~/go/bin/cfssl ~/go/bin/cfssljson /usr/local/bin/ - name: "Build crowdsec and fixture" run: TEST_COVERAGE=true make bats-clean bats-build bats-fixture diff --git a/cmd/crowdsec-cli/bouncers.go b/cmd/crowdsec-cli/bouncers.go index 28e1d614f..acd402a22 100644 --- a/cmd/crowdsec-cli/bouncers.go +++ b/cmd/crowdsec-cli/bouncers.go @@ -9,6 +9,7 @@ import ( middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/enescakir/emoji" "github.com/olekukonko/tablewriter" log "github.com/sirupsen/logrus" @@ -65,7 +66,7 @@ Note: This command requires database direct access, so is intended to be run on table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) table.SetAlignment(tablewriter.ALIGN_LEFT) - table.SetHeader([]string{"Name", "IP Address", "Valid", "Last API pull", "Type", "Version"}) + table.SetHeader([]string{"Name", "IP Address", "Valid", "Last API pull", "Type", "Version", "Auth Type"}) for _, b := range blockers { var revoked string if !b.Revoked { @@ -73,7 +74,7 @@ Note: This command requires database direct access, so is intended to be run on } else { revoked = emoji.Prohibited.String() } - table.Append([]string{b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version}) + table.Append([]string{b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version, b.AuthType}) } table.Render() } else if csConfig.Cscli.Output == "json" { @@ -84,7 +85,7 @@ Note: This command requires database direct access, so is intended to be run on fmt.Printf("%s", string(x)) } else if csConfig.Cscli.Output == "raw" { csvwriter := csv.NewWriter(os.Stdout) - err := csvwriter.Write([]string{"name", "ip", "revoked", "last_pull", "type", "version"}) + err := csvwriter.Write([]string{"name", "ip", "revoked", "last_pull", "type", "version", "auth_type"}) if err != nil { log.Fatalf("failed to write raw header: %s", err) } @@ -95,7 +96,7 @@ Note: This command requires database direct access, so is intended to be run on } else { revoked = "pending" } - err := csvwriter.Write([]string{b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version}) + err := csvwriter.Write([]string{b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version, b.AuthType}) if err != nil { log.Fatalf("failed to write raw: %s", err) } @@ -129,7 +130,7 @@ cscli bouncers add MyBouncerName -k %s`, generatePassword(32)), if err != nil { log.Fatalf("unable to generate api key: %s", err) } - err = dbClient.CreateBouncer(keyName, keyIP, middlewares.HashSHA512(apiKey)) + _, err = dbClient.CreateBouncer(keyName, keyIP, middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) if err != nil { log.Fatalf("unable to create bouncer: %s", err) } diff --git a/cmd/crowdsec-cli/config.go b/cmd/crowdsec-cli/config.go index 495f9292f..b9dd00439 100644 --- a/cmd/crowdsec-cli/config.go +++ b/cmd/crowdsec-cli/config.go @@ -368,6 +368,29 @@ func NewConfigCmd() *cobra.Command { if csConfig.API.Server.TLS.KeyFilePath != "" { fmt.Printf(" - Key File : %s\n", csConfig.API.Server.TLS.KeyFilePath) } + if csConfig.API.Server.TLS.CACertPath != "" { + fmt.Printf(" - CA Cert : %s\n", csConfig.API.Server.TLS.CACertPath) + } + if csConfig.API.Server.TLS.CRLPath != "" { + fmt.Printf(" - CRL : %s\n", csConfig.API.Server.TLS.CRLPath) + } + if csConfig.API.Server.TLS.CacheExpiration != nil { + fmt.Printf(" - Cache Expiration : %s\n", csConfig.API.Server.TLS.CacheExpiration) + } + if csConfig.API.Server.TLS.ClientVerification != "" { + fmt.Printf(" - Client Verification : %s\n", csConfig.API.Server.TLS.ClientVerification) + } + if csConfig.API.Server.TLS.AllowedAgentsOU != nil { + for _, ou := range csConfig.API.Server.TLS.AllowedAgentsOU { + fmt.Printf(" - Allowed Agents OU : %s\n", ou) + } + } + if csConfig.API.Server.TLS.AllowedBouncersOU != nil { + for _, ou := range csConfig.API.Server.TLS.AllowedBouncersOU { + fmt.Printf(" - Allowed Bouncers OU : %s\n", ou) + } + } + } fmt.Printf(" - Trusted IPs: \n") for _, ip := range csConfig.API.Server.TrustedIPs { diff --git a/cmd/crowdsec-cli/machines.go b/cmd/crowdsec-cli/machines.go index c0267f763..b48aa89fd 100644 --- a/cmd/crowdsec-cli/machines.go +++ b/cmd/crowdsec-cli/machines.go @@ -14,6 +14,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/machineid" "github.com/enescakir/emoji" "github.com/go-openapi/strfmt" @@ -140,7 +141,7 @@ Note: This command requires database direct access, so is intended to be run on table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) table.SetAlignment(tablewriter.ALIGN_LEFT) - table.SetHeader([]string{"Name", "IP Address", "Last Update", "Status", "Version", "Last Heartbeat"}) + table.SetHeader([]string{"Name", "IP Address", "Last Update", "Status", "Version", "Auth Type", "Last Heartbeat"}) for _, w := range machines { var validated string if w.IsValidated { @@ -153,7 +154,7 @@ Note: This command requires database direct access, so is intended to be run on if lastHeartBeat > 2*time.Minute { hbDisplay = fmt.Sprintf("%s %s", emoji.Warning.String(), lastHeartBeat.Truncate(time.Second).String()) } - table.Append([]string{w.MachineId, w.IpAddress, w.UpdatedAt.Format(time.RFC3339), validated, w.Version, hbDisplay}) + table.Append([]string{w.MachineId, w.IpAddress, w.UpdatedAt.Format(time.RFC3339), validated, w.Version, w.AuthType, hbDisplay}) } table.Render() } else if csConfig.Cscli.Output == "json" { @@ -164,7 +165,7 @@ Note: This command requires database direct access, so is intended to be run on fmt.Printf("%s", string(x)) } else if csConfig.Cscli.Output == "raw" { csvwriter := csv.NewWriter(os.Stdout) - err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "last_heartbeat"}) + err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "auth_type", "last_heartbeat"}) if err != nil { log.Fatalf("failed to write header: %s", err) } @@ -175,7 +176,7 @@ Note: This command requires database direct access, so is intended to be run on } else { validated = "false" } - err := csvwriter.Write([]string{w.MachineId, w.IpAddress, w.UpdatedAt.Format(time.RFC3339), validated, w.Version, time.Now().UTC().Sub(*w.LastHeartbeat).Truncate(time.Second).String()}) + err := csvwriter.Write([]string{w.MachineId, w.IpAddress, w.UpdatedAt.Format(time.RFC3339), validated, w.Version, w.AuthType, time.Now().UTC().Sub(*w.LastHeartbeat).Truncate(time.Second).String()}) if err != nil { log.Fatalf("failed to write raw output : %s", err) } @@ -244,7 +245,7 @@ cscli machines add MyTestMachine --password MyPassword survey.AskOne(qs, &machinePassword) } password := strfmt.Password(machinePassword) - _, err = dbClient.CreateMachine(&machineID, &password, "", true, forceAdd) + _, err = dbClient.CreateMachine(&machineID, &password, "", true, forceAdd, types.PasswordAuthType) if err != nil { log.Fatalf("unable to create machine: %s", err) } diff --git a/pkg/apiclient/client.go b/pkg/apiclient/client.go index 8d48993ac..0f758f1e3 100644 --- a/pkg/apiclient/client.go +++ b/pkg/apiclient/client.go @@ -3,6 +3,7 @@ package apiclient import ( "context" "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io/ioutil" @@ -15,6 +16,8 @@ import ( var ( InsecureSkipVerify = false + Cert *tls.Certificate + CaCertPool *x509.CertPool ) type ApiClient struct { @@ -49,7 +52,12 @@ func NewClient(config *Config) (*ApiClient, error) { VersionPrefix: config.VersionPrefix, UpdateScenario: config.UpdateScenario, } - http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: InsecureSkipVerify} + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} + if Cert != nil { + tlsconfig.RootCAs = CaCertPool + tlsconfig.Certificates = []tls.Certificate{*Cert} + } + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) @@ -66,7 +74,12 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt if client == nil { client = &http.Client{} if ht, ok := http.DefaultTransport.(*http.Transport); ok { - ht.TLSClientConfig = &tls.Config{InsecureSkipVerify: InsecureSkipVerify} + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} + if Cert != nil { + tlsconfig.RootCAs = CaCertPool + tlsconfig.Certificates = []tls.Certificate{*Cert} + } + ht.TLSClientConfig = &tlsconfig client.Transport = ht } } @@ -86,7 +99,12 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) { if client == nil { client = &http.Client{} } - http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: InsecureSkipVerify} + tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify} + if Cert != nil { + tlsconfig.RootCAs = CaCertPool + tlsconfig.Certificates = []tls.Certificate{*Cert} + } + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix} c.common.client = c c.Decisions = (*DecisionsService)(&c.common) diff --git a/pkg/apiserver/alerts_test.go b/pkg/apiserver/alerts_test.go index 44437133a..6c98a2fbf 100644 --- a/pkg/apiserver/alerts_test.go +++ b/pkg/apiserver/alerts_test.go @@ -45,17 +45,22 @@ func SetupLAPITest(t *testing.T) LAPI { func (l *LAPI) InsertAlertFromFile(path string) *httptest.ResponseRecorder { alertReader := GetAlertReaderFromFile(path) - return l.RecordResponse("POST", "/v1/alerts", alertReader) + return l.RecordResponse("POST", "/v1/alerts", alertReader, "password") } -func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader) *httptest.ResponseRecorder { +func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder { w := httptest.NewRecorder() req, err := http.NewRequest(verb, url, body) if err != nil { l.t.Fatal(err) } - req.Header.Add("X-Api-Key", l.bouncerKey) - AddAuthHeaders(req, l.loginResp) + if authType == "apikey" { + req.Header.Add("X-Api-Key", l.bouncerKey) + } else if authType == "password" { + AddAuthHeaders(req, l.loginResp) + } else { + l.t.Fatal("auth type not supported") + } l.router.ServeHTTP(w, req) return w } @@ -93,6 +98,7 @@ func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherA if err != nil { return models.WatcherAuthResponse{}, fmt.Errorf("%s", err.Error()) } + return loginResp, nil } @@ -107,13 +113,13 @@ func TestSimulatedAlert(t *testing.T) { alertContent := GetAlertReaderFromFile("./tests/alert_minibulk+simul.json") //exclude decision in simulation mode - w := lapi.RecordResponse("GET", "/v1/alerts?simulated=false", alertContent) + w := lapi.RecordResponse("GET", "/v1/alerts?simulated=false", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) //include decision in simulation mode - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", alertContent) + w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", alertContent, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) @@ -123,14 +129,14 @@ func TestCreateAlert(t *testing.T) { lapi := SetupLAPITest(t) // Create Alert with invalid format - w := lapi.RecordResponse("POST", "/v1/alerts", strings.NewReader("test")) + w := lapi.RecordResponse("POST", "/v1/alerts", strings.NewReader("test"), "password") assert.Equal(t, 400, w.Code) assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) // Create Alert with invalid input alertContent := GetAlertReaderFromFile("./tests/invalidAlert_sample.json") - w = lapi.RecordResponse("POST", "/v1/alerts", alertContent) + w = lapi.RecordResponse("POST", "/v1/alerts", alertContent, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, "{\"message\":\"validation failure list:\\n0.scenario in body is required\\n0.scenario_hash in body is required\\n0.scenario_version in body is required\\n0.simulated in body is required\\n0.source in body is required\"}", w.Body.String()) @@ -177,13 +183,13 @@ func TestAlertListFilters(t *testing.T) { //bad filter - w := lapi.RecordResponse("GET", "/v1/alerts?test=test", alertContent) + w := lapi.RecordResponse("GET", "/v1/alerts?test=test", alertContent, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String()) //get without filters - w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) //check alert and decision assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") @@ -191,149 +197,149 @@ func TestAlertListFilters(t *testing.T) { //test decision_type filter (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ban", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ban", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test decision_type filter (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ratata", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ratata", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test scope (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?scope=Ip", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?scope=Ip", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test scope (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?scope=rarara", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?scope=rarara", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test scenario (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test scenario (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test ip (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=91.121.79.195", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test ip (bad value) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=99.122.77.195", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test ip (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?ip=gruueq", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?ip=gruueq", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String()) //test range (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test range - w = lapi.RecordResponse("GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test range (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?range=ratata", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?range=ratata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String()) //test since (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1h", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?since=1h", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test since (ok but yields no results) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1ns", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?since=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test since (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?since=1zuzu", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?since=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) //test until (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1ns", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?until=1ns", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test until (ok but no return) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1m", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?until=1m", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test until (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?until=1zuzu", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?until=1zuzu", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) //test simulated (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test simulated (ok) - w = lapi.RecordResponse("GET", "/v1/alerts?simulated=false", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?simulated=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test has active decision - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=true", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=true", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) //test has active decision - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=false", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, "null", w.Body.String()) //test has active decision (invalid value) - w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) @@ -345,7 +351,7 @@ func TestAlertBulkInsert(t *testing.T) { lapi.InsertAlertFromFile("./tests/alert_bulk.json") alertContent := GetAlertReaderFromFile("./tests/alert_bulk.json") - w := lapi.RecordResponse("GET", "/v1/alerts", alertContent) + w := lapi.RecordResponse("GET", "/v1/alerts", alertContent, "password") assert.Equal(t, 200, w.Code) } @@ -354,13 +360,13 @@ func TestListAlert(t *testing.T) { lapi.InsertAlertFromFile("./tests/alert_sample.json") // List Alert with invalid filter - w := lapi.RecordResponse("GET", "/v1/alerts?test=test", emptyBody) + w := lapi.RecordResponse("GET", "/v1/alerts?test=test", emptyBody, "password") assert.Equal(t, 500, w.Code) assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String()) // List Alert - w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody) + w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Contains(t, w.Body.String(), "crowdsecurity/test") } diff --git a/pkg/apiserver/apiserver.go b/pkg/apiserver/apiserver.go index 1cc32b570..9655f1dbd 100644 --- a/pkg/apiserver/apiserver.go +++ b/pkg/apiserver/apiserver.go @@ -2,8 +2,11 @@ package apiserver import ( "context" + "crypto/tls" + "crypto/x509" "fmt" "io" + "io/ioutil" "net" "net/http" "os" @@ -11,6 +14,7 @@ import ( "time" "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers" + v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/database" @@ -240,12 +244,56 @@ func (s *APIServer) Router() (*gin.Engine, error) { return s.router, nil } +func (s *APIServer) GetTLSConfig() (*tls.Config, error) { + var caCert []byte + var err error + var caCertPool *x509.CertPool + var clientAuthType tls.ClientAuthType + + if s.TLS == nil { + return &tls.Config{}, nil + } + + if s.TLS.ClientVerification == "" { + //sounds like a sane default : verify client cert if given, but don't make it mandatory + clientAuthType = tls.VerifyClientCertIfGiven + } else { + clientAuthType, err = getTLSAuthType(s.TLS.ClientVerification) + if err != nil { + return nil, err + } + } + + if s.TLS.CACertPath != "" { + if clientAuthType > tls.RequestClientCert { + log.Infof("(tls) Client Auth Type set to %s", clientAuthType.String()) + caCert, err = ioutil.ReadFile(s.TLS.CACertPath) + if err != nil { + return nil, errors.Wrap(err, "Error opening cert file") + } + caCertPool = x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + } + } + + return &tls.Config{ + ServerName: s.TLS.ServerName, //should it be removed ? + ClientAuth: clientAuthType, + ClientCAs: caCertPool, + MinVersion: tls.VersionTLS12, // TLS versions below 1.2 are considered insecure - see https://www.rfc-editor.org/rfc/rfc7525.txt for details + }, nil +} + func (s *APIServer) Run() error { defer types.CatchPanic("lapi/runServer") - + tlsCfg, err := s.GetTLSConfig() + if err != nil { + return errors.Wrap(err, "while creating TLS config") + } s.httpServer = &http.Server{ - Addr: s.URL, - Handler: s.router, + Addr: s.URL, + Handler: s.router, + TLSConfig: tlsCfg, } if s.apic != nil { @@ -326,6 +374,36 @@ func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) { } func (s *APIServer) InitController() error { + err := s.controller.Init() + if err != nil { + return errors.Wrap(err, "controller init") + } + if s.TLS != nil { + var cacheExpiration time.Duration + if s.TLS.CacheExpiration != nil { + cacheExpiration = *s.TLS.CacheExpiration + } else { + cacheExpiration = time.Hour + } + s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath, + cacheExpiration, + log.WithFields(log.Fields{ + "component": "tls-auth", + "type": "agent", + })) + if err != nil { + return errors.Wrap(err, "while creating TLS auth for agents") + } + s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath, + cacheExpiration, + log.WithFields(log.Fields{ + "component": "tls-auth", + "type": "bouncer", + })) + if err != nil { + return errors.Wrap(err, "while creating TLS auth for bouncers") + } + } return err } diff --git a/pkg/apiserver/apiserver_test.go b/pkg/apiserver/apiserver_test.go index d1236681e..efdb05e16 100644 --- a/pkg/apiserver/apiserver_test.go +++ b/pkg/apiserver/apiserver_test.go @@ -275,7 +275,7 @@ func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) { if err != nil { return "", fmt.Errorf("unable to generate api key: %s", err) } - err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey)) + _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType) if err != nil { return "", fmt.Errorf("unable to create blocker: %s", err) } diff --git a/pkg/apiserver/controllers/controller.go b/pkg/apiserver/controllers/controller.go index a93aed6de..5dddf71fc 100644 --- a/pkg/apiserver/controllers/controller.go +++ b/pkg/apiserver/controllers/controller.go @@ -25,6 +25,7 @@ type Controller struct { Log *log.Logger ConsoleConfig *csconfig.ConsoleConfig TrustedIPs []net.IPNet + HandlerV1 *v1.Controller } func (c *Controller) Init() error { @@ -55,12 +56,22 @@ func serveHealth() http.HandlerFunc { } func (c *Controller) NewV1() error { + var err error - handlerV1, err := v1.New(c.DBClient, c.Ectx, c.Profiles, c.CAPIChan, c.PluginChannel, *c.ConsoleConfig, c.TrustedIPs) + v1Config := v1.ControllerV1Config{ + DbClient: c.DBClient, + Ctx: c.Ectx, + Profiles: c.Profiles, + CapiChan: c.CAPIChan, + PluginChannel: c.PluginChannel, + ConsoleConfig: *c.ConsoleConfig, + TrustedIPs: c.TrustedIPs, + } + + c.HandlerV1, err = v1.New(&v1Config) if err != nil { return err } - c.Router.GET("/health", gin.WrapF(serveHealth())) c.Router.Use(v1.PrometheusMiddleware()) c.Router.HandleMethodNotAllowed = true @@ -72,31 +83,32 @@ func (c *Controller) NewV1() error { }) groupV1 := c.Router.Group("/v1") - groupV1.POST("/watchers", handlerV1.CreateMachine) - groupV1.POST("/watchers/login", handlerV1.Middlewares.JWT.Middleware.LoginHandler) + groupV1.POST("/watchers", c.HandlerV1.CreateMachine) + groupV1.POST("/watchers/login", c.HandlerV1.Middlewares.JWT.Middleware.LoginHandler) jwtAuth := groupV1.Group("") - jwtAuth.GET("/refresh_token", handlerV1.Middlewares.JWT.Middleware.RefreshHandler) - jwtAuth.Use(handlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware()) + jwtAuth.GET("/refresh_token", c.HandlerV1.Middlewares.JWT.Middleware.RefreshHandler) + jwtAuth.Use(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware()) { - jwtAuth.POST("/alerts", handlerV1.CreateAlert) - jwtAuth.GET("/alerts", handlerV1.FindAlerts) - jwtAuth.HEAD("/alerts", handlerV1.FindAlerts) - jwtAuth.GET("/alerts/:alert_id", handlerV1.FindAlertByID) - jwtAuth.HEAD("/alerts/:alert_id", handlerV1.FindAlertByID) - jwtAuth.DELETE("/alerts", handlerV1.DeleteAlerts) - jwtAuth.DELETE("/decisions", handlerV1.DeleteDecisions) - jwtAuth.DELETE("/decisions/:decision_id", handlerV1.DeleteDecisionById) - jwtAuth.GET("/heartbeat", handlerV1.HeartBeat) + jwtAuth.POST("/alerts", c.HandlerV1.CreateAlert) + jwtAuth.GET("/alerts", c.HandlerV1.FindAlerts) + jwtAuth.HEAD("/alerts", c.HandlerV1.FindAlerts) + jwtAuth.GET("/alerts/:alert_id", c.HandlerV1.FindAlertByID) + jwtAuth.HEAD("/alerts/:alert_id", c.HandlerV1.FindAlertByID) + jwtAuth.DELETE("/alerts", c.HandlerV1.DeleteAlerts) + jwtAuth.DELETE("/decisions", c.HandlerV1.DeleteDecisions) + jwtAuth.DELETE("/decisions/:decision_id", c.HandlerV1.DeleteDecisionById) + jwtAuth.GET("/heartbeat", c.HandlerV1.HeartBeat) + } apiKeyAuth := groupV1.Group("") - apiKeyAuth.Use(handlerV1.Middlewares.APIKey.MiddlewareFunc(), v1.PrometheusBouncersMiddleware()) + apiKeyAuth.Use(c.HandlerV1.Middlewares.APIKey.MiddlewareFunc(), v1.PrometheusBouncersMiddleware()) { - apiKeyAuth.GET("/decisions", handlerV1.GetDecision) - apiKeyAuth.HEAD("/decisions", handlerV1.GetDecision) - apiKeyAuth.GET("/decisions/stream", handlerV1.StreamDecision) - apiKeyAuth.HEAD("/decisions/stream", handlerV1.StreamDecision) + apiKeyAuth.GET("/decisions", c.HandlerV1.GetDecision) + apiKeyAuth.HEAD("/decisions", c.HandlerV1.GetDecision) + apiKeyAuth.GET("/decisions/stream", c.HandlerV1.StreamDecision) + apiKeyAuth.HEAD("/decisions/stream", c.HandlerV1.StreamDecision) } return nil diff --git a/pkg/apiserver/controllers/v1/controller.go b/pkg/apiserver/controllers/v1/controller.go index c236cda7d..13e9f730c 100644 --- a/pkg/apiserver/controllers/v1/controller.go +++ b/pkg/apiserver/controllers/v1/controller.go @@ -4,6 +4,8 @@ import ( "context" "net" + //"github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers" + middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csplugin" @@ -23,19 +25,29 @@ type Controller struct { TrustedIPs []net.IPNet } -func New(dbClient *database.Client, ctx context.Context, profiles []*csconfig.ProfileCfg, capiChan chan []*models.Alert, pluginChannel chan csplugin.ProfileAlert, consoleConfig csconfig.ConsoleConfig, trustedIPs []net.IPNet) (*Controller, error) { +type ControllerV1Config struct { + DbClient *database.Client + Ctx context.Context + Profiles []*csconfig.ProfileCfg + CapiChan chan []*models.Alert + PluginChannel chan csplugin.ProfileAlert + ConsoleConfig csconfig.ConsoleConfig + TrustedIPs []net.IPNet +} + +func New(cfg *ControllerV1Config) (*Controller, error) { var err error v1 := &Controller{ - Ectx: ctx, - DBClient: dbClient, + Ectx: cfg.Ctx, + DBClient: cfg.DbClient, APIKeyHeader: middlewares.APIKeyHeader, - Profiles: profiles, - CAPIChan: capiChan, - PluginChannel: pluginChannel, - ConsoleConfig: consoleConfig, - TrustedIPs: trustedIPs, + Profiles: cfg.Profiles, + CAPIChan: cfg.CapiChan, + PluginChannel: cfg.PluginChannel, + ConsoleConfig: cfg.ConsoleConfig, + TrustedIPs: cfg.TrustedIPs, } - v1.Middlewares, err = middlewares.NewMiddlewares(dbClient) + v1.Middlewares, err = middlewares.NewMiddlewares(cfg.DbClient) if err != nil { return v1, err } diff --git a/pkg/apiserver/controllers/v1/machines.go b/pkg/apiserver/controllers/v1/machines.go index 1f4b339c0..b4f28d94f 100644 --- a/pkg/apiserver/controllers/v1/machines.go +++ b/pkg/apiserver/controllers/v1/machines.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" ) @@ -20,7 +21,7 @@ func (c *Controller) CreateMachine(gctx *gin.Context) { return } - _, err = c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false) + _, err = c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType) if err != nil { c.HandleDBErrors(gctx, err) return diff --git a/pkg/apiserver/decisions_test.go b/pkg/apiserver/decisions_test.go index 0447cbf9a..b830f136e 100644 --- a/pkg/apiserver/decisions_test.go +++ b/pkg/apiserver/decisions_test.go @@ -14,20 +14,20 @@ func TestDeleteDecisionRange(t *testing.T) { lapi.InsertAlertFromFile("./tests/alert_minibulk.json") // delete by ip wrong - w := lapi.RecordResponse("DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody) + w := lapi.RecordResponse("DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by range - w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) // delete by range : ensure it was already deleted - w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) } @@ -40,19 +40,19 @@ func TestDeleteDecisionFilter(t *testing.T) { // delete by ip wrong - w := lapi.RecordResponse("DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody) + w := lapi.RecordResponse("DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) // delete by ip good - w = lapi.RecordResponse("DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) // delete by scope/value - w = lapi.RecordResponse("DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, "password") assert.Equal(t, 200, w.Code) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) } @@ -65,7 +65,7 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision - w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody) + w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, "apikey") assert.Equal(t, 200, w.Code) decisions, code, err := readDecisionsGetResp(w) assert.Nil(t, err) @@ -80,7 +80,7 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : type filter - w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, "apikey") assert.Equal(t, 200, w.Code) decisions, code, err = readDecisionsGetResp(w) assert.Nil(t, err) @@ -98,7 +98,7 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : scope/value - w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, "apikey") assert.Equal(t, 200, w.Code) decisions, code, err = readDecisionsGetResp(w) assert.Nil(t, err) @@ -113,7 +113,7 @@ func TestGetDecisionFilters(t *testing.T) { // Get Decision : ip filter - w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, "apikey") assert.Equal(t, 200, w.Code) decisions, code, err = readDecisionsGetResp(w) assert.Nil(t, err) @@ -127,7 +127,7 @@ func TestGetDecisionFilters(t *testing.T) { // assert.NotContains(t, w.Body.String(), `"id":2,"origin":"crowdsec","scenario":"crowdsecurity/ssh-bf","scope":"Ip","type":"ban","value":"91.121.79.178"`) // Get decision : by range - w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, "apikey") assert.Equal(t, 200, w.Code) decisions, code, err = readDecisionsGetResp(w) assert.Nil(t, err) @@ -145,7 +145,7 @@ func TestGetDecision(t *testing.T) { lapi.InsertAlertFromFile("./tests/alert_sample.json") // Get Decision - w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody) + w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, "apikey") assert.Equal(t, 200, w.Code) decisions, code, err := readDecisionsGetResp(w) assert.Nil(t, err) @@ -165,7 +165,7 @@ func TestGetDecision(t *testing.T) { assert.Equal(t, int64(3), decisions[2].ID) // Get Decision with invalid filter. It should ignore this filter - w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, "apikey") assert.Equal(t, 200, w.Code) assert.Equal(t, 3, len(decisions)) } @@ -177,7 +177,7 @@ func TestDeleteDecisionByID(t *testing.T) { lapi.InsertAlertFromFile("./tests/alert_sample.json") //Have one alerts - w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) + w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey") decisions, code, err := readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, code, 200) @@ -185,21 +185,21 @@ func TestDeleteDecisionByID(t *testing.T) { assert.Equal(t, len(decisions["new"]), 1) // Delete alert with Invalid ID - w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, "password") assert.Equal(t, 400, w.Code) err_resp, _, err := readDecisionsErrorResp(w) assert.NoError(t, err) assert.Equal(t, err_resp["message"], "decision_id must be valid integer") // Delete alert with ID that not exist - w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, "password") assert.Equal(t, 500, w.Code) err_resp, _, err = readDecisionsErrorResp(w) assert.NoError(t, err) assert.Equal(t, err_resp["message"], "decision with id '100' doesn't exist: unable to delete") //Have one alerts - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, code, 200) @@ -207,14 +207,14 @@ func TestDeleteDecisionByID(t *testing.T) { assert.Equal(t, len(decisions["new"]), 1) // Delete alert with valid ID - w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, "password") assert.Equal(t, 200, w.Code) resp, _, err := readDecisionsDeleteResp(w) assert.NoError(t, err) assert.Equal(t, resp.NbDeleted, "1") //Have one alert (because we delete an alert that has dup targets) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, code, 200) @@ -229,14 +229,14 @@ func TestDeleteDecision(t *testing.T) { lapi.InsertAlertFromFile("./tests/alert_sample.json") // Delete alert with Invalid filter - w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody) + w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, "password") assert.Equal(t, 500, w.Code) err_resp, _, err := readDecisionsErrorResp(w) assert.NoError(t, err) assert.Equal(t, err_resp["message"], "'test' doesn't exist: invalid filter") // Delete all alert - w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, "password") assert.Equal(t, 200, w.Code) resp, _, err := readDecisionsDeleteResp(w) assert.NoError(t, err) @@ -251,7 +251,7 @@ func TestStreamStartDecisionDedup(t *testing.T) { lapi.InsertAlertFromFile("./tests/alert_sample.json") // Get Stream, we only get one decision (the longest one) - w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) + w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey") decisions, code, err := readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, code, 200) @@ -262,11 +262,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1") // id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip - w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody, "password") assert.Equal(t, 200, w.Code) // Get Stream, we only get one decision (the longest one, id=2) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, code, 200) @@ -277,11 +277,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1") // We delete another decision, yet don't receive it in stream, since there's another decision on same IP - w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, "password") assert.Equal(t, 200, w.Code) // And get the remaining decision (1) - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, code, 200) @@ -292,11 +292,11 @@ func TestStreamStartDecisionDedup(t *testing.T) { assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1") // We delete the last decision, we receive the delete order - w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, "password") assert.Equal(t, 200, w.Code) //and now we only get a deleted decision - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, code, 200) @@ -317,7 +317,7 @@ func TestStreamDecisionDedup(t *testing.T) { time.Sleep(2 * time.Second) // Get Stream, we only get one decision (the longest one) - w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) + w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey") decisions, code, err := readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, code, 200) @@ -328,10 +328,10 @@ func TestStreamDecisionDedup(t *testing.T) { assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1") // id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip - w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody, "password") assert.Equal(t, 200, w.Code) - w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody, "apikey") assert.Equal(t, err, nil) decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) @@ -339,10 +339,10 @@ func TestStreamDecisionDedup(t *testing.T) { assert.Equal(t, len(decisions["deleted"]), 0) assert.Equal(t, len(decisions["new"]), 0) // We delete another decision, yet don't receive it in stream, since there's another decision on same IP - w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, "password") assert.Equal(t, 200, w.Code) - w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, code, 200) @@ -350,10 +350,10 @@ func TestStreamDecisionDedup(t *testing.T) { assert.Equal(t, len(decisions["new"]), 0) // We delete the last decision, we receive the delete order - w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody) + w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, "password") assert.Equal(t, 200, w.Code) - w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, code, 200) @@ -371,7 +371,7 @@ func TestStreamDecisionFilters(t *testing.T) { // Create Valid Alert lapi.InsertAlertFromFile("./tests/alert_stream_fixture.json") - w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) + w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey") decisions, code, err := readDecisionsStreamResp(w) assert.Equal(t, 200, code) @@ -392,7 +392,7 @@ func TestStreamDecisionFilters(t *testing.T) { assert.Equal(t, *decisions["new"][2].Scenario, "crowdsecurity/ddos") // test filter scenarios_not_containing - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_not_containing=http", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_not_containing=http", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, 200, code) @@ -402,7 +402,7 @@ func TestStreamDecisionFilters(t *testing.T) { assert.Equal(t, decisions["new"][1].ID, int64(3)) // test filter scenarios_containing - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_containing=http", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_containing=http", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, 200, code) @@ -411,7 +411,7 @@ func TestStreamDecisionFilters(t *testing.T) { assert.Equal(t, decisions["new"][0].ID, int64(1)) // test filters both by scenarios_not_containing and scenarios_containing - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh&scenarios_containing=ddos", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh&scenarios_containing=ddos", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, 200, code) @@ -420,7 +420,7 @@ func TestStreamDecisionFilters(t *testing.T) { assert.Equal(t, decisions["new"][0].ID, int64(3)) // test filter by origin - w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&origins=test1,test2", emptyBody) + w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&origins=test1,test2", emptyBody, "apikey") decisions, code, err = readDecisionsStreamResp(w) assert.Equal(t, err, nil) assert.Equal(t, 200, code) diff --git a/pkg/apiserver/heartbeat_test.go b/pkg/apiserver/heartbeat_test.go index f09b0d37e..01fac2f0b 100644 --- a/pkg/apiserver/heartbeat_test.go +++ b/pkg/apiserver/heartbeat_test.go @@ -9,9 +9,9 @@ import ( func TestHeartBeat(t *testing.T) { lapi := SetupLAPITest(t) - w := lapi.RecordResponse("GET", "/v1/heartbeat", emptyBody) + w := lapi.RecordResponse("GET", "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 200, w.Code) - w = lapi.RecordResponse("POST", "/v1/heartbeat", emptyBody) + w = lapi.RecordResponse("POST", "/v1/heartbeat", emptyBody, "password") assert.Equal(t, 405, w.Code) } diff --git a/pkg/apiserver/middlewares/v1/api_key.go b/pkg/apiserver/middlewares/v1/api_key.go index 23267d36f..e129f20cc 100644 --- a/pkg/apiserver/middlewares/v1/api_key.go +++ b/pkg/apiserver/middlewares/v1/api_key.go @@ -9,6 +9,8 @@ import ( "strings" "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" + "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/gin-gonic/gin" log "github.com/sirupsen/logrus" ) @@ -21,6 +23,7 @@ var ( type APIKey struct { HeaderName string DbClient *database.Client + TlsAuth *TLSAuth } func GenerateAPIKey(n int) (string, error) { @@ -35,6 +38,7 @@ func NewAPIKey(dbClient *database.Client) *APIKey { return &APIKey{ HeaderName: APIKeyHeader, DbClient: dbClient, + TlsAuth: &TLSAuth{}, } } @@ -49,34 +53,132 @@ func HashSHA512(str string) string { func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { return func(c *gin.Context) { - val, ok := c.Request.Header[APIKeyHeader] - if !ok { - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return - } + var bouncer *ent.Bouncer + var err error - hashStr := HashSHA512(val[0]) - bouncer, err := a.DbClient.SelectBouncer(hashStr) - if err != nil { - log.Errorf("auth api key error: %s", err) - c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) - c.Abort() - return + if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { + if a.TlsAuth == nil { + log.WithField("ip", c.ClientIP()).Error("TLS Auth is not configured but client presented a certificate") + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return + } + validCert, extractedCN, err := a.TlsAuth.ValidateCert(c) + if !validCert { + log.WithField("ip", c.ClientIP()).Errorf("invalid client certificate: %s", err) + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return + } + if err != nil { + log.WithField("ip", c.ClientIP()).Error(err) + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return + } + bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) + bouncer, err = a.DbClient.SelectBouncerByName(bouncerName) + //This is likely not the proper way, but isNotFound does not seem to work + if err != nil && strings.Contains(err.Error(), "bouncer not found") { + //Because we have a valid cert, automatically create the bouncer in the database if it does not exist + //Set a random API key, but it will never be used + apiKey, err := GenerateAPIKey(64) + if err != nil { + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "cn": extractedCN, + }).Errorf("error generating mock api key: %s", err) + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return + } + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "cn": extractedCN, + }).Infof("Creating bouncer %s", bouncerName) + bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType) + if err != nil { + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "cn": extractedCN, + }).Errorf("creating bouncer db entry : %s", err) + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return + } + } else if err != nil { + //error while selecting bouncer + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "cn": extractedCN, + }).Errorf("while selecting bouncers: %s", err) + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return + } else { + //bouncer was found in DB + if bouncer.AuthType != types.TlsAuthType { + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "cn": extractedCN, + }).Errorf("bouncer isn't allowed to auth by TLS") + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return + } + } + } else { + //API Key Authentication + val, ok := c.Request.Header[APIKeyHeader] + if !ok { + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + }).Errorf("API key not found") + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return + } + hashStr := HashSHA512(val[0]) + bouncer, err = a.DbClient.SelectBouncer(hashStr) + if err != nil { + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + }).Errorf("while fetching bouncer info: %s", err) + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return + } + if bouncer.AuthType != types.ApiKeyAuthType { + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + }).Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType) + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return + } } if bouncer == nil { + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + }).Errorf("bouncer not found") c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() return } + //maybe we want to store the whole bouncer object in the context instead, this would avoid another db query + //in StreamDecision c.Set("BOUNCER_NAME", bouncer.Name) + c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey) if bouncer.IPAddress == "" { err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID) if err != nil { - log.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "name": bouncer.Name, + }).Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() return @@ -87,7 +189,10 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress) err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID) if err != nil { - log.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "name": bouncer.Name, + }).Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort() return @@ -97,13 +202,19 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { useragent := strings.Split(c.Request.UserAgent(), "/") if len(useragent) != 2 { - log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), c.ClientIP()) + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "name": bouncer.Name, + }).Warningf("bad user agent '%s'", c.Request.UserAgent()) useragent = []string{c.Request.UserAgent(), "N/A"} } if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { - log.Errorf("failed to update bouncer version and type from '%s' (%s): %s", c.Request.UserAgent(), c.ClientIP(), err) + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "name": bouncer.Name, + }).Errorf("failed to update bouncer version and type: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.Abort() return diff --git a/pkg/apiserver/middlewares/v1/jwt.go b/pkg/apiserver/middlewares/v1/jwt.go index 51b82db82..3ae8f6d0b 100644 --- a/pkg/apiserver/middlewares/v1/jwt.go +++ b/pkg/apiserver/middlewares/v1/jwt.go @@ -3,14 +3,17 @@ package v1 import ( "crypto/rand" "fmt" + "net/http" "os" "strings" "time" jwt "github.com/appleboy/gin-jwt/v2" "github.com/crowdsecurity/crowdsec/pkg/database" + "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/models" + "github.com/crowdsecurity/crowdsec/pkg/types" "github.com/gin-gonic/gin" "github.com/go-openapi/strfmt" "github.com/pkg/errors" @@ -23,6 +26,7 @@ var identityKey = "id" type JWT struct { Middleware *jwt.GinJWTMiddleware DbClient *database.Client + TlsAuth *TLSAuth } func PayloadFunc(data interface{}) jwt.MapClaims { @@ -46,35 +50,109 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { var loginInput models.WatcherAuthRequest var scenarios string var err error - if err := c.ShouldBindJSON(&loginInput); err != nil { - return "", errors.Wrap(err, "missing") - } - if err := loginInput.Validate(strfmt.Default); err != nil { - return "", errors.New("input format error") - } - machineID := *loginInput.MachineID - password := *loginInput.Password - scenariosInput := loginInput.Scenarios + var scenariosInput []string + var clientMachine *ent.Machine + var machineID string + var password strfmt.Password - machine, err := j.DbClient.Ent.Machine.Query(). - Where(machine.MachineId(machineID)). - First(j.DbClient.CTX) - if err != nil { - log.Printf("Error machine login for %s : %+v ", machineID, err) - return nil, err - } + if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 { + if j.TlsAuth == nil { + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return nil, errors.New("TLS auth is not configured") + } + validCert, extractedCN, err := j.TlsAuth.ValidateCert(c) + if err != nil { + log.Error(err) + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return nil, errors.Wrap(err, "while trying to validate client cert") + } + if !validCert { + c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) + c.Abort() + return nil, fmt.Errorf("failed cert authentication") + } - if machine == nil { - log.Errorf("Nothing for '%s'", machineID) - return nil, jwt.ErrFailedAuthentication - } + machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP()) + clientMachine, err = j.DbClient.Ent.Machine.Query(). + Where(machine.MachineId(machineID)). + First(j.DbClient.CTX) + if ent.IsNotFound(err) { + //Machine was not found, let's create it + log.Printf("machine %s not found, create it", machineID) + //let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli) + pwd, err := GenerateAPIKey(64) + if err != nil { + log.WithFields(log.Fields{ + "ip": c.ClientIP(), + "cn": extractedCN, + }).Errorf("error generating password: %s", err) + return nil, fmt.Errorf("error generating password") + } + password := strfmt.Password(pwd) + clientMachine, err = j.DbClient.CreateMachine(&machineID, &password, "", true, true, types.TlsAuthType) + if err != nil { + return "", errors.Wrapf(err, "while creating machine entry for %s", machineID) + } + } else if err != nil { + return "", errors.Wrapf(err, "while selecting machine entry for %s", machineID) + } else { + if clientMachine.AuthType != types.TlsAuthType { + return "", errors.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", machineID, clientMachine.AuthType) + } + machineID = clientMachine.MachineId + loginInput := struct { + Scenarios []string `json:"scenarios"` + }{ + Scenarios: []string{}, + } + err := c.ShouldBindJSON(&loginInput) + if err != nil { + return "", errors.Wrap(err, "missing scenarios list in login request for TLS auth") + } + scenariosInput = loginInput.Scenarios + } - if !machine.IsValidated { - return nil, fmt.Errorf("machine %s not validated", machineID) - } + } else { + //normal auth - if err = bcrypt.CompareHashAndPassword([]byte(machine.Password), []byte(password)); err != nil { - return nil, jwt.ErrFailedAuthentication + if err := c.ShouldBindJSON(&loginInput); err != nil { + return "", errors.Wrap(err, "missing") + } + if err := loginInput.Validate(strfmt.Default); err != nil { + return "", errors.New("input format error") + } + machineID = *loginInput.MachineID + password = *loginInput.Password + scenariosInput = loginInput.Scenarios + + clientMachine, err = j.DbClient.Ent.Machine.Query(). + Where(machine.MachineId(machineID)). + First(j.DbClient.CTX) + if err != nil { + log.Printf("Error machine login for %s : %+v ", machineID, err) + return nil, err + } + + if clientMachine == nil { + log.Errorf("Nothing for '%s'", machineID) + return nil, jwt.ErrFailedAuthentication + } + + if clientMachine.AuthType != types.PasswordAuthType { + return nil, errors.Errorf("machine %s attempted to auth with password but it is configured to use %s", machineID, clientMachine.AuthType) + } + + if !clientMachine.IsValidated { + return nil, fmt.Errorf("machine %s not validated", machineID) + } + + if err = bcrypt.CompareHashAndPassword([]byte(clientMachine.Password), []byte(password)); err != nil { + return nil, jwt.ErrFailedAuthentication + } + + //end of normal auth } if len(scenariosInput) > 0 { @@ -85,26 +163,26 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { scenarios += "," + scenario } } - err = j.DbClient.UpdateMachineScenarios(scenarios, machine.ID) + err = j.DbClient.UpdateMachineScenarios(scenarios, clientMachine.ID) if err != nil { log.Errorf("Failed to update scenarios list for '%s': %s\n", machineID, err) return nil, jwt.ErrFailedAuthentication } } - if machine.IpAddress == "" { - err = j.DbClient.UpdateMachineIP(c.ClientIP(), machine.ID) + if clientMachine.IpAddress == "" { + err = j.DbClient.UpdateMachineIP(c.ClientIP(), clientMachine.ID) if err != nil { log.Errorf("Failed to update ip address for '%s': %s\n", machineID, err) return nil, jwt.ErrFailedAuthentication } } - if machine.IpAddress != c.ClientIP() && machine.IpAddress != "" { - log.Warningf("new IP address detected for machine '%s': %s (old: %s)", machine.MachineId, c.ClientIP(), machine.IpAddress) - err = j.DbClient.UpdateMachineIP(c.ClientIP(), machine.ID) + if clientMachine.IpAddress != c.ClientIP() && clientMachine.IpAddress != "" { + log.Warningf("new IP address detected for machine '%s': %s (old: %s)", clientMachine.MachineId, c.ClientIP(), clientMachine.IpAddress) + err = j.DbClient.UpdateMachineIP(c.ClientIP(), clientMachine.ID) if err != nil { - log.Errorf("Failed to update ip address for '%s': %s\n", machine.MachineId, err) + log.Errorf("Failed to update ip address for '%s': %s\n", clientMachine.MachineId, err) return nil, jwt.ErrFailedAuthentication } } @@ -115,12 +193,11 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) { return nil, jwt.ErrFailedAuthentication } - if err := j.DbClient.UpdateMachineVersion(useragent[1], machine.ID); err != nil { - log.Errorf("unable to update machine '%s' version '%s': %s", machine.MachineId, useragent[1], err) + if err := j.DbClient.UpdateMachineVersion(useragent[1], clientMachine.ID); err != nil { + log.Errorf("unable to update machine '%s' version '%s': %s", clientMachine.MachineId, useragent[1], err) log.Errorf("bad user agent from : %s", c.ClientIP()) return nil, jwt.ErrFailedAuthentication } - return &models.WatcherAuthRequest{ MachineID: &machineID, }, nil @@ -178,6 +255,7 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { jwtMiddleware := &JWT{ DbClient: dbClient, + TlsAuth: &TLSAuth{}, } ret, err := jwt.New(&jwt.GinJWTMiddleware{ @@ -195,15 +273,15 @@ func NewJWT(dbClient *database.Client) (*JWT, error) { TokenHeadName: "Bearer", TimeFunc: time.Now, }) + if err != nil { + return &JWT{}, err + } errInit := ret.MiddlewareInit() if errInit != nil { return &JWT{}, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) } + jwtMiddleware.Middleware = ret - if err != nil { - return &JWT{}, err - } - - return &JWT{Middleware: ret}, nil + return jwtMiddleware, nil } diff --git a/pkg/apiserver/middlewares/v1/middlewares.go b/pkg/apiserver/middlewares/v1/middlewares.go index 7777f5857..26879bd8e 100644 --- a/pkg/apiserver/middlewares/v1/middlewares.go +++ b/pkg/apiserver/middlewares/v1/middlewares.go @@ -18,6 +18,5 @@ func NewMiddlewares(dbClient *database.Client) (*Middlewares, error) { } ret.APIKey = NewAPIKey(dbClient) - return ret, nil } diff --git a/pkg/apiserver/middlewares/v1/tls_auth.go b/pkg/apiserver/middlewares/v1/tls_auth.go new file mode 100644 index 000000000..a65a52d55 --- /dev/null +++ b/pkg/apiserver/middlewares/v1/tls_auth.go @@ -0,0 +1,256 @@ +package v1 + +import ( + "bytes" + "crypto" + "crypto/x509" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/ocsp" +) + +type TLSAuth struct { + AllowedOUs []string + CrlPath string + revokationCache map[string]cacheEntry + cacheExpiration time.Duration + logger *log.Entry +} + +type cacheEntry struct { + revoked bool + err error + timestamp time.Time +} + +func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) { + req, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256}) + if err != nil { + ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err) + return nil, err + } + httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req)) + if err != nil { + ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP") + return nil, err + } + ocspURL, err := url.Parse(server) + if err != nil { + ta.logger.Error("TLSAuth: cannot parse OCSP URL") + return nil, err + } + httpRequest.Header.Add("Content-Type", "application/ocsp-request") + httpRequest.Header.Add("Accept", "application/ocsp-response") + httpRequest.Header.Add("host", ocspURL.Host) + httpClient := &http.Client{} + httpResponse, err := httpClient.Do(httpRequest) + if err != nil { + ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP") + return nil, err + } + defer httpResponse.Body.Close() + output, err := ioutil.ReadAll(httpResponse.Body) + if err != nil { + ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP") + return nil, err + } + ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer) + return ocspResponse, err +} + +func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool { + now := time.Now().UTC() + + if cert.NotAfter.UTC().Before(now) { + ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC()) + return true + } + if cert.NotBefore.UTC().After(now) { + ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC()) + return true + } + return false +} + +func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { + if cert.OCSPServer == nil || (cert.OCSPServer != nil && len(cert.OCSPServer) == 0) { + ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification") + return false, nil + } + for _, server := range cert.OCSPServer { + ocspResponse, err := ta.ocspQuery(server, cert, issuer) + if err != nil { + ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err) + continue + } + switch ocspResponse.Status { + case ocsp.Good: + return false, nil + case ocsp.Revoked: + return true, fmt.Errorf("client certificate is revoked by server %s", server) + case ocsp.Unknown: + log.Debugf("unknow OCSP status for server %s", server) + continue + } + } + log.Infof("Could not get any valid OCSP response, assuming the cert is revoked") + return true, nil +} + +func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) { + if ta.CrlPath == "" { + ta.logger.Warn("no crl_path, skipping CRL check") + return false, nil + } + crlContent, err := ioutil.ReadFile(ta.CrlPath) + if err != nil { + ta.logger.Warnf("could not read CRL file, skipping check: %s", err) + return false, nil + } + crl, err := x509.ParseCRL(crlContent) + if err != nil { + ta.logger.Warnf("could not parse CRL file, skipping check: %s", err) + return false, nil + } + if crl.HasExpired(time.Now().UTC()) { + ta.logger.Warn("CRL has expired, will still validate the cert against it.") + } + for _, revoked := range crl.TBSCertList.RevokedCertificates { + if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 { + return true, fmt.Errorf("client certificate is revoked by CRL") + } + } + return false, nil +} + +func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { + sn := cert.SerialNumber.String() + if cacheValue, ok := ta.revokationCache[sn]; ok { + if time.Now().UTC().Sub(cacheValue.timestamp) < ta.cacheExpiration { + ta.logger.Debugf("TLSAuth: using cached value for cert %s: %t | %s", sn, cacheValue.revoked, cacheValue.err) + return cacheValue.revoked, cacheValue.err + } else { + ta.logger.Debugf("TLSAuth: cached value expired, removing from cache") + delete(ta.revokationCache, sn) + } + } else { + ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn) + } + revoked, err := ta.isOCSPRevoked(cert, issuer) + if err != nil { + ta.revokationCache[sn] = cacheEntry{ + revoked: revoked, + err: err, + timestamp: time.Now().UTC(), + } + return true, err + } + if revoked { + ta.revokationCache[sn] = cacheEntry{ + revoked: revoked, + err: err, + timestamp: time.Now().UTC(), + } + return true, nil + } + revoked, err = ta.isCRLRevoked(cert) + ta.revokationCache[sn] = cacheEntry{ + revoked: revoked, + err: err, + timestamp: time.Now().UTC(), + } + return revoked, err +} + +func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) { + if ta.isExpired(cert) { + return true, nil + } + revoked, err := ta.isRevoked(cert, issuer) + if err != nil { + //Fail securely, if we can't check the revokation status, let's consider the cert invalid + //We may change this in the future based on users feedback, but this seems the most sensible thing to do + return true, errors.Wrap(err, "could not check for client certification revokation status") + } + + return revoked, nil +} + +func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error { + for _, ou := range allowedOus { + //disallow empty ou + if ou == "" { + return fmt.Errorf("empty ou isn't allowed") + } + //drop & warn on duplicate ou + ok := true + for _, validOu := range ta.AllowedOUs { + if validOu == ou { + ta.logger.Warningf("dropping duplicate ou %s", ou) + ok = false + } + } + if ok { + ta.AllowedOUs = append(ta.AllowedOUs, ou) + } + } + return nil +} + +func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) { + //Checks cert validity, Returns true + CN if client cert matches requested OU + var clientCert *x509.Certificate + if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 { + //do not error if it's not TLS or there are no peer certs + return false, "", nil + } + + if len(c.Request.TLS.VerifiedChains) > 0 { + validOU := false + clientCert = c.Request.TLS.VerifiedChains[0][0] + for _, ou := range clientCert.Subject.OrganizationalUnit { + for _, allowedOu := range ta.AllowedOUs { + if allowedOu == ou { + validOU = true + break + } + } + } + if !validOU { + return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)", + clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) + } + revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1]) + if err != nil { + ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err) + return false, "", errors.Wrap(err, "could not check for client certification revokation status") + } + if revoked { + return false, "", fmt.Errorf("client certificate is revoked") + } + ta.logger.Infof("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs) + return true, clientCert.Subject.CommonName, nil + } + return false, "", fmt.Errorf("no verified cert in request") +} + +func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Duration, logger *log.Entry) (*TLSAuth, error) { + ta := &TLSAuth{ + revokationCache: map[string]cacheEntry{}, + cacheExpiration: cacheExpiration, + CrlPath: crlPath, + logger: logger, + } + err := ta.SetAllowedOu(allowedOus) + if err != nil { + return nil, err + } + return ta, nil +} diff --git a/pkg/apiserver/utils.go b/pkg/apiserver/utils.go new file mode 100644 index 000000000..409d79b01 --- /dev/null +++ b/pkg/apiserver/utils.go @@ -0,0 +1,27 @@ +package apiserver + +import ( + "crypto/tls" + "fmt" + + log "github.com/sirupsen/logrus" +) + +func getTLSAuthType(authType string) (tls.ClientAuthType, error) { + switch authType { + case "NoClientCert": + return tls.NoClientCert, nil + case "RequestClientCert": + log.Warn("RequestClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead") + return tls.RequestClientCert, nil + case "RequireAnyClientCert": + log.Warn("RequireAnyClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead") + return tls.RequireAnyClientCert, nil + case "VerifyClientCertIfGiven": + return tls.VerifyClientCertIfGiven, nil + case "RequireAndVerifyClientCert": + return tls.RequireAndVerifyClientCert, nil + default: + return 0, fmt.Errorf("unknown TLS client_verification value: %s", authType) + } +} diff --git a/pkg/csconfig/api.go b/pkg/csconfig/api.go index b546b1b56..4751b5d7e 100644 --- a/pkg/csconfig/api.go +++ b/pkg/csconfig/api.go @@ -1,10 +1,13 @@ package csconfig import ( + "crypto/tls" + "crypto/x509" "fmt" "io/ioutil" "net" "strings" + "time" "github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/yamlpatch" @@ -19,9 +22,12 @@ type APICfg struct { } type ApiCredentialsCfg struct { - URL string `yaml:"url,omitempty" json:"url,omitempty"` - Login string `yaml:"login,omitempty" json:"login,omitempty"` - Password string `yaml:"password,omitempty" json:"-"` + URL string `yaml:"url,omitempty" json:"url,omitempty"` + Login string `yaml:"login,omitempty" json:"login,omitempty"` + Password string `yaml:"password,omitempty" json:"-"` + CACertPath string `yaml:"ca_cert_path,omitempty"` + KeyPath string `yaml:"key_path,omitempty"` + CertPath string `yaml:"cert_path,omitempty"` } /*global api config (for lapi->oapi)*/ @@ -73,11 +79,34 @@ func (l *LocalApiClientCfg) Load() error { l.Credentials.URL = l.Credentials.URL + "/" } } + + if l.Credentials.Login != "" && (l.Credentials.CACertPath != "" || l.Credentials.CertPath != "" || l.Credentials.KeyPath != "") { + return fmt.Errorf("user/password authentication and TLS authentication are mutually exclusive") + } + if l.InsecureSkipVerify == nil { apiclient.InsecureSkipVerify = false } else { apiclient.InsecureSkipVerify = *l.InsecureSkipVerify } + + if l.Credentials.CACertPath != "" && l.Credentials.CertPath != "" && l.Credentials.KeyPath != "" { + cert, err := tls.LoadX509KeyPair(l.Credentials.CertPath, l.Credentials.KeyPath) + if err != nil { + return errors.Wrapf(err, "failed to load api client certificate") + } + + caCert, err := ioutil.ReadFile(l.Credentials.CACertPath) + if err != nil { + return errors.Wrapf(err, "failed to load cacert") + } + + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + apiclient.Cert = &cert + apiclient.CaCertPool = caCertPool + } return nil } @@ -128,8 +157,15 @@ type LocalApiServerCfg struct { } type TLSCfg struct { - CertFilePath string `yaml:"cert_file"` - KeyFilePath string `yaml:"key_file"` + CertFilePath string `yaml:"cert_file"` + KeyFilePath string `yaml:"key_file"` + ClientVerification string `yaml:"client_verification,omitempty"` + ServerName string `yaml:"server_name"` + CACertPath string `yaml:"ca_cert_path"` + AllowedAgentsOU []string `yaml:"agents_allowed_ou"` + AllowedBouncersOU []string `yaml:"bouncers_allowed_ou"` + CRLPath string `yaml:"crl_path"` + CacheExpiration *time.Duration `yaml:"cache_expiration,omitempty"` } func (c *Config) LoadAPIServer() error { diff --git a/pkg/csconfig/database.go b/pkg/csconfig/database.go index 979b0da8c..b2fd18d7d 100644 --- a/pkg/csconfig/database.go +++ b/pkg/csconfig/database.go @@ -2,6 +2,7 @@ package csconfig import ( "fmt" + "time" "github.com/crowdsecurity/crowdsec/pkg/types" log "github.com/sirupsen/logrus" @@ -23,9 +24,20 @@ type DatabaseCfg struct { MaxOpenConns *int `yaml:"max_open_conns,omitempty"` } +type AuthGCCfg struct { + Cert *string `yaml:"cert,omitempty"` + CertDuration *time.Duration + Api *string `yaml:"api_key,omitempty"` + ApiDuration *time.Duration + LoginPassword *string `yaml:"login_password,omitempty"` + LoginPasswordDuration *time.Duration +} + type FlushDBCfg struct { - MaxItems *int `yaml:"max_items"` - MaxAge *string `yaml:"max_age"` + MaxItems *int `yaml:"max_items,omitempty"` + MaxAge *string `yaml:"max_age,omitempty"` + BouncersGC *AuthGCCfg `yaml:"bouncers_autodelete,omitempty"` + AgentsGC *AuthGCCfg `yaml:"agents_autodelete,omitempty"` } func (c *Config) LoadDBConfig() error { diff --git a/pkg/database/alerts.go b/pkg/database/alerts.go index aabdcc3c3..d9d833c4f 100644 --- a/pkg/database/alerts.go +++ b/pkg/database/alerts.go @@ -7,10 +7,13 @@ import ( "strings" "time" + "github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event" + "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" "github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/types" @@ -890,6 +893,77 @@ func (c *Client) FlushOrphans() { } } +func (c *Client) FlushAgentsAndBouncers(agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error { + log.Printf("starting FlushAgentsAndBouncers") + if bouncersCfg != nil { + if bouncersCfg.ApiDuration != nil { + log.Printf("trying to delete old bouncers from api") + deletionCount, err := c.Ent.Bouncer.Delete().Where( + bouncer.LastPullLTE(time.Now().UTC().Add(*bouncersCfg.ApiDuration)), + ).Where( + bouncer.AuthTypeEQ(types.ApiKeyAuthType), + ).Exec(c.CTX) + if err != nil { + c.Log.Errorf("while auto-deleting expired bouncers (api key) : %s", err) + } else if deletionCount > 0 { + c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount) + } + } + if bouncersCfg.CertDuration != nil { + log.Printf("trying to delete old bouncers from cert") + + deletionCount, err := c.Ent.Bouncer.Delete().Where( + bouncer.LastPullLTE(time.Now().UTC().Add(*bouncersCfg.CertDuration)), + ).Where( + bouncer.AuthTypeEQ(types.TlsAuthType), + ).Exec(c.CTX) + if err != nil { + c.Log.Errorf("while auto-deleting expired bouncers (api key) : %s", err) + } else if deletionCount > 0 { + c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount) + } + } + } + + if agentsCfg != nil { + if agentsCfg.CertDuration != nil { + log.Printf("trying to delete old agents from cert") + + deletionCount, err := c.Ent.Machine.Delete().Where( + machine.LastPushLTE(time.Now().UTC().Add(*agentsCfg.CertDuration)), + ).Where( + machine.Not(machine.HasAlerts()), + ).Where( + machine.AuthTypeEQ(types.TlsAuthType), + ).Exec(c.CTX) + log.Printf("deleted %d entries", deletionCount) + if err != nil { + c.Log.Errorf("while auto-deleting expired machine (cert) : %s", err) + } else if deletionCount > 0 { + c.Log.Infof("deleted %d expired machine (cert auth)", deletionCount) + } + } + if agentsCfg.LoginPasswordDuration != nil { + log.Printf("trying to delete old agents from password") + + deletionCount, err := c.Ent.Machine.Delete().Where( + machine.LastPushLTE(time.Now().UTC().Add(*agentsCfg.LoginPasswordDuration)), + ).Where( + machine.Not(machine.HasAlerts()), + ).Where( + machine.AuthTypeEQ(types.PasswordAuthType), + ).Exec(c.CTX) + log.Printf("deleted %d entries", deletionCount) + if err != nil { + c.Log.Errorf("while auto-deleting expired machine (password) : %s", err) + } else if deletionCount > 0 { + c.Log.Infof("deleted %d expired machine (password auth)", deletionCount) + } + } + } + return nil +} + func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { var deletedByAge int var deletedByNbItem int diff --git a/pkg/database/bouncers.go b/pkg/database/bouncers.go index fecabcb73..808786d4d 100644 --- a/pkg/database/bouncers.go +++ b/pkg/database/bouncers.go @@ -18,6 +18,15 @@ func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) { return result, nil } +func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) { + result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX) + if err != nil { + return &ent.Bouncer{}, errors.Wrapf(QueryFail, "select bouncer: %s", err) + } + + return result, nil +} + func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { result, err := c.Ent.Bouncer.Query().All(c.CTX) if err != nil { @@ -26,20 +35,21 @@ func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { return result, nil } -func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string) error { - _, err := c.Ent.Bouncer. +func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) { + bouncer, err := c.Ent.Bouncer. Create(). SetName(name). SetAPIKey(apiKey). SetRevoked(false). + SetAuthType(authType). Save(c.CTX) if err != nil { if ent.IsConstraintError(err) { - return fmt.Errorf("bouncer %s already exists", name) + return nil, fmt.Errorf("bouncer %s already exists", name) } - return fmt.Errorf("unable to save api key in database: %s", err) + return nil, fmt.Errorf("unable to save api key in database: %s", err) } - return nil + return bouncer, nil } func (c *Client) DeleteBouncer(name string) error { diff --git a/pkg/database/database.go b/pkg/database/database.go index cd745fa26..30d38c4a2 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -122,14 +122,61 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched if config.MaxItems != nil { maxItems = *config.MaxItems } - if config.MaxAge != nil && *config.MaxAge != "" { maxAge = *config.MaxAge } - // Init & Start cronjob every minute + + // Init & Start cronjob every minute for alerts scheduler := gocron.NewScheduler(time.UTC) - job, _ := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) + job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) + if err != nil { + return nil, errors.Wrap(err, "while starting FlushAlerts scheduler") + } job.SingletonMode() + // Init & Start cronjob every hour for bouncers/agents + if config.AgentsGC != nil { + if config.AgentsGC.Cert != nil { + duration, err := types.ParseDuration(*config.AgentsGC.Cert) + if err != nil { + return nil, errors.Wrap(err, "while parsing agents cert auto-delete duration") + } + config.AgentsGC.CertDuration = &duration + } + if config.AgentsGC.LoginPassword != nil { + duration, err := types.ParseDuration(*config.AgentsGC.LoginPassword) + if err != nil { + return nil, errors.Wrap(err, "while parsing agents login/password auto-delete duration") + } + config.AgentsGC.LoginPasswordDuration = &duration + } + if config.AgentsGC.Api != nil { + log.Warningf("agents auto-delete for API auth is not supported (use cert or login_password)") + } + } + if config.BouncersGC != nil { + if config.BouncersGC.Cert != nil { + duration, err := types.ParseDuration(*config.BouncersGC.Cert) + if err != nil { + return nil, errors.Wrap(err, "while parsing bouncers cert auto-delete duration") + } + config.BouncersGC.CertDuration = &duration + } + if config.BouncersGC.Api != nil { + duration, err := types.ParseDuration(*config.BouncersGC.Api) + if err != nil { + return nil, errors.Wrap(err, "while parsing bouncers api auto-delete duration") + } + config.BouncersGC.ApiDuration = &duration + } + if config.BouncersGC.LoginPassword != nil { + log.Warningf("bouncers auto-delete for login/password auth is not supported (use cert or api)") + } + } + baJob, err := scheduler.Every(1).Minute().Do(c.FlushAgentsAndBouncers, config.AgentsGC, config.BouncersGC) + if err != nil { + return nil, errors.Wrap(err, "while starting FlushAgentsAndBouncers scheduler") + } + baJob.SingletonMode() scheduler.StartAsync() return scheduler, nil diff --git a/pkg/database/ent/bouncer.go b/pkg/database/ent/bouncer.go index 084a52d90..ad6fa1cf2 100644 --- a/pkg/database/ent/bouncer.go +++ b/pkg/database/ent/bouncer.go @@ -36,6 +36,8 @@ type Bouncer struct { Until time.Time `json:"until"` // LastPull holds the value of the "last_pull" field. LastPull time.Time `json:"last_pull"` + // AuthType holds the value of the "auth_type" field. + AuthType string `json:"auth_type"` } // scanValues returns the types for scanning values from sql.Rows. @@ -47,7 +49,7 @@ func (*Bouncer) scanValues(columns []string) ([]interface{}, error) { values[i] = new(sql.NullBool) case bouncer.FieldID: values[i] = new(sql.NullInt64) - case bouncer.FieldName, bouncer.FieldAPIKey, bouncer.FieldIPAddress, bouncer.FieldType, bouncer.FieldVersion: + case bouncer.FieldName, bouncer.FieldAPIKey, bouncer.FieldIPAddress, bouncer.FieldType, bouncer.FieldVersion, bouncer.FieldAuthType: values[i] = new(sql.NullString) case bouncer.FieldCreatedAt, bouncer.FieldUpdatedAt, bouncer.FieldUntil, bouncer.FieldLastPull: values[i] = new(sql.NullTime) @@ -134,6 +136,12 @@ func (b *Bouncer) assignValues(columns []string, values []interface{}) error { } else if value.Valid { b.LastPull = value.Time } + case bouncer.FieldAuthType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field auth_type", values[i]) + } else if value.Valid { + b.AuthType = value.String + } } } return nil @@ -186,6 +194,8 @@ func (b *Bouncer) String() string { builder.WriteString(b.Until.Format(time.ANSIC)) builder.WriteString(", last_pull=") builder.WriteString(b.LastPull.Format(time.ANSIC)) + builder.WriteString(", auth_type=") + builder.WriteString(b.AuthType) builder.WriteByte(')') return builder.String() } diff --git a/pkg/database/ent/bouncer/bouncer.go b/pkg/database/ent/bouncer/bouncer.go index f05ac9ac7..10fbe00b5 100644 --- a/pkg/database/ent/bouncer/bouncer.go +++ b/pkg/database/ent/bouncer/bouncer.go @@ -31,6 +31,8 @@ const ( FieldUntil = "until" // FieldLastPull holds the string denoting the last_pull field in the database. FieldLastPull = "last_pull" + // FieldAuthType holds the string denoting the auth_type field in the database. + FieldAuthType = "auth_type" // Table holds the table name of the bouncer in the database. Table = "bouncers" ) @@ -48,6 +50,7 @@ var Columns = []string{ FieldVersion, FieldUntil, FieldLastPull, + FieldAuthType, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -75,4 +78,6 @@ var ( DefaultUntil func() time.Time // DefaultLastPull holds the default value on creation for the "last_pull" field. DefaultLastPull func() time.Time + // DefaultAuthType holds the default value on creation for the "auth_type" field. + DefaultAuthType string ) diff --git a/pkg/database/ent/bouncer/where.go b/pkg/database/ent/bouncer/where.go index ae652edc3..15d16cd21 100644 --- a/pkg/database/ent/bouncer/where.go +++ b/pkg/database/ent/bouncer/where.go @@ -162,6 +162,13 @@ func LastPull(v time.Time) predicate.Bouncer { }) } +// AuthType applies equality check predicate on the "auth_type" field. It's identical to AuthTypeEQ. +func AuthType(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldAuthType), v)) + }) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Bouncer { return predicate.Bouncer(func(s *sql.Selector) { @@ -1119,6 +1126,117 @@ func LastPullLTE(v time.Time) predicate.Bouncer { }) } +// AuthTypeEQ applies the EQ predicate on the "auth_type" field. +func AuthTypeEQ(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeNEQ applies the NEQ predicate on the "auth_type" field. +func AuthTypeNEQ(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeIn applies the In predicate on the "auth_type" field. +func AuthTypeIn(vs ...string) predicate.Bouncer { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Bouncer(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldAuthType), v...)) + }) +} + +// AuthTypeNotIn applies the NotIn predicate on the "auth_type" field. +func AuthTypeNotIn(vs ...string) predicate.Bouncer { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Bouncer(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldAuthType), v...)) + }) +} + +// AuthTypeGT applies the GT predicate on the "auth_type" field. +func AuthTypeGT(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeGTE applies the GTE predicate on the "auth_type" field. +func AuthTypeGTE(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeLT applies the LT predicate on the "auth_type" field. +func AuthTypeLT(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeLTE applies the LTE predicate on the "auth_type" field. +func AuthTypeLTE(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeContains applies the Contains predicate on the "auth_type" field. +func AuthTypeContains(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field. +func AuthTypeHasPrefix(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field. +func AuthTypeHasSuffix(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field. +func AuthTypeEqualFold(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field. +func AuthTypeContainsFold(v string) predicate.Bouncer { + return predicate.Bouncer(func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldAuthType), v)) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.Bouncer) predicate.Bouncer { return predicate.Bouncer(func(s *sql.Selector) { diff --git a/pkg/database/ent/bouncer_create.go b/pkg/database/ent/bouncer_create.go index 2d3f2f5c2..42746542e 100644 --- a/pkg/database/ent/bouncer_create.go +++ b/pkg/database/ent/bouncer_create.go @@ -136,6 +136,20 @@ func (bc *BouncerCreate) SetNillableLastPull(t *time.Time) *BouncerCreate { return bc } +// SetAuthType sets the "auth_type" field. +func (bc *BouncerCreate) SetAuthType(s string) *BouncerCreate { + bc.mutation.SetAuthType(s) + return bc +} + +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (bc *BouncerCreate) SetNillableAuthType(s *string) *BouncerCreate { + if s != nil { + bc.SetAuthType(*s) + } + return bc +} + // Mutation returns the BouncerMutation object of the builder. func (bc *BouncerCreate) Mutation() *BouncerMutation { return bc.mutation @@ -227,6 +241,10 @@ func (bc *BouncerCreate) defaults() { v := bouncer.DefaultLastPull() bc.mutation.SetLastPull(v) } + if _, ok := bc.mutation.AuthType(); !ok { + v := bouncer.DefaultAuthType + bc.mutation.SetAuthType(v) + } } // check runs all checks and user-defined validators on the builder. @@ -243,6 +261,9 @@ func (bc *BouncerCreate) check() error { if _, ok := bc.mutation.LastPull(); !ok { return &ValidationError{Name: "last_pull", err: errors.New(`ent: missing required field "Bouncer.last_pull"`)} } + if _, ok := bc.mutation.AuthType(); !ok { + return &ValidationError{Name: "auth_type", err: errors.New(`ent: missing required field "Bouncer.auth_type"`)} + } return nil } @@ -350,6 +371,14 @@ func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) { }) _node.LastPull = value } + if value, ok := bc.mutation.AuthType(); ok { + _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: bouncer.FieldAuthType, + }) + _node.AuthType = value + } return _node, _spec } diff --git a/pkg/database/ent/bouncer_update.go b/pkg/database/ent/bouncer_update.go index 571b4dcec..385fcbc7f 100644 --- a/pkg/database/ent/bouncer_update.go +++ b/pkg/database/ent/bouncer_update.go @@ -164,6 +164,20 @@ func (bu *BouncerUpdate) SetNillableLastPull(t *time.Time) *BouncerUpdate { return bu } +// SetAuthType sets the "auth_type" field. +func (bu *BouncerUpdate) SetAuthType(s string) *BouncerUpdate { + bu.mutation.SetAuthType(s) + return bu +} + +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (bu *BouncerUpdate) SetNillableAuthType(s *string) *BouncerUpdate { + if s != nil { + bu.SetAuthType(*s) + } + return bu +} + // Mutation returns the BouncerMutation object of the builder. func (bu *BouncerUpdate) Mutation() *BouncerMutation { return bu.mutation @@ -360,6 +374,13 @@ func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) { Column: bouncer.FieldLastPull, }) } + if value, ok := bu.mutation.AuthType(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: bouncer.FieldAuthType, + }) + } if n, err = sqlgraph.UpdateNodes(ctx, bu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{bouncer.Label} @@ -515,6 +536,20 @@ func (buo *BouncerUpdateOne) SetNillableLastPull(t *time.Time) *BouncerUpdateOne return buo } +// SetAuthType sets the "auth_type" field. +func (buo *BouncerUpdateOne) SetAuthType(s string) *BouncerUpdateOne { + buo.mutation.SetAuthType(s) + return buo +} + +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (buo *BouncerUpdateOne) SetNillableAuthType(s *string) *BouncerUpdateOne { + if s != nil { + buo.SetAuthType(*s) + } + return buo +} + // Mutation returns the BouncerMutation object of the builder. func (buo *BouncerUpdateOne) Mutation() *BouncerMutation { return buo.mutation @@ -735,6 +770,13 @@ func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err e Column: bouncer.FieldLastPull, }) } + if value, ok := buo.mutation.AuthType(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: bouncer.FieldAuthType, + }) + } _node = &Bouncer{config: buo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/pkg/database/ent/machine.go b/pkg/database/ent/machine.go index 44d198510..b1096be6b 100644 --- a/pkg/database/ent/machine.go +++ b/pkg/database/ent/machine.go @@ -38,6 +38,8 @@ type Machine struct { IsValidated bool `json:"isValidated,omitempty"` // Status holds the value of the "status" field. Status string `json:"status,omitempty"` + // AuthType holds the value of the "auth_type" field. + AuthType string `json:"auth_type"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the MachineQuery when eager-loading is set. Edges MachineEdges `json:"edges"` @@ -70,7 +72,7 @@ func (*Machine) scanValues(columns []string) ([]interface{}, error) { values[i] = new(sql.NullBool) case machine.FieldID: values[i] = new(sql.NullInt64) - case machine.FieldMachineId, machine.FieldPassword, machine.FieldIpAddress, machine.FieldScenarios, machine.FieldVersion, machine.FieldStatus: + case machine.FieldMachineId, machine.FieldPassword, machine.FieldIpAddress, machine.FieldScenarios, machine.FieldVersion, machine.FieldStatus, machine.FieldAuthType: values[i] = new(sql.NullString) case machine.FieldCreatedAt, machine.FieldUpdatedAt, machine.FieldLastPush, machine.FieldLastHeartbeat: values[i] = new(sql.NullTime) @@ -165,6 +167,12 @@ func (m *Machine) assignValues(columns []string, values []interface{}) error { } else if value.Valid { m.Status = value.String } + case machine.FieldAuthType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field auth_type", values[i]) + } else if value.Valid { + m.AuthType = value.String + } } } return nil @@ -227,6 +235,8 @@ func (m *Machine) String() string { builder.WriteString(fmt.Sprintf("%v", m.IsValidated)) builder.WriteString(", status=") builder.WriteString(m.Status) + builder.WriteString(", auth_type=") + builder.WriteString(m.AuthType) builder.WriteByte(')') return builder.String() } diff --git a/pkg/database/ent/machine/machine.go b/pkg/database/ent/machine/machine.go index b8e6c71fe..efed1eaab 100644 --- a/pkg/database/ent/machine/machine.go +++ b/pkg/database/ent/machine/machine.go @@ -33,6 +33,8 @@ const ( FieldIsValidated = "is_validated" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" + // FieldAuthType holds the string denoting the auth_type field in the database. + FieldAuthType = "auth_type" // EdgeAlerts holds the string denoting the alerts edge name in mutations. EdgeAlerts = "alerts" // Table holds the table name of the machine in the database. @@ -60,6 +62,7 @@ var Columns = []string{ FieldVersion, FieldIsValidated, FieldStatus, + FieldAuthType, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -93,4 +96,6 @@ var ( ScenariosValidator func(string) error // DefaultIsValidated holds the default value on creation for the "isValidated" field. DefaultIsValidated bool + // DefaultAuthType holds the default value on creation for the "auth_type" field. + DefaultAuthType string ) diff --git a/pkg/database/ent/machine/where.go b/pkg/database/ent/machine/where.go index 361405e8e..0f77047bd 100644 --- a/pkg/database/ent/machine/where.go +++ b/pkg/database/ent/machine/where.go @@ -170,6 +170,13 @@ func Status(v string) predicate.Machine { }) } +// AuthType applies equality check predicate on the "auth_type" field. It's identical to AuthTypeEQ. +func AuthType(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldAuthType), v)) + }) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Machine { return predicate.Machine(func(s *sql.Selector) { @@ -1252,6 +1259,117 @@ func StatusContainsFold(v string) predicate.Machine { }) } +// AuthTypeEQ applies the EQ predicate on the "auth_type" field. +func AuthTypeEQ(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeNEQ applies the NEQ predicate on the "auth_type" field. +func AuthTypeNEQ(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeIn applies the In predicate on the "auth_type" field. +func AuthTypeIn(vs ...string) predicate.Machine { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Machine(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldAuthType), v...)) + }) +} + +// AuthTypeNotIn applies the NotIn predicate on the "auth_type" field. +func AuthTypeNotIn(vs ...string) predicate.Machine { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Machine(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldAuthType), v...)) + }) +} + +// AuthTypeGT applies the GT predicate on the "auth_type" field. +func AuthTypeGT(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeGTE applies the GTE predicate on the "auth_type" field. +func AuthTypeGTE(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeLT applies the LT predicate on the "auth_type" field. +func AuthTypeLT(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeLTE applies the LTE predicate on the "auth_type" field. +func AuthTypeLTE(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeContains applies the Contains predicate on the "auth_type" field. +func AuthTypeContains(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field. +func AuthTypeHasPrefix(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field. +func AuthTypeHasSuffix(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field. +func AuthTypeEqualFold(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldAuthType), v)) + }) +} + +// AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field. +func AuthTypeContainsFold(v string) predicate.Machine { + return predicate.Machine(func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldAuthType), v)) + }) +} + // HasAlerts applies the HasEdge predicate on the "alerts" edge. func HasAlerts() predicate.Machine { return predicate.Machine(func(s *sql.Selector) { diff --git a/pkg/database/ent/machine_create.go b/pkg/database/ent/machine_create.go index 2d76a3fc9..3f0369050 100644 --- a/pkg/database/ent/machine_create.go +++ b/pkg/database/ent/machine_create.go @@ -151,6 +151,20 @@ func (mc *MachineCreate) SetNillableStatus(s *string) *MachineCreate { return mc } +// SetAuthType sets the "auth_type" field. +func (mc *MachineCreate) SetAuthType(s string) *MachineCreate { + mc.mutation.SetAuthType(s) + return mc +} + +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (mc *MachineCreate) SetNillableAuthType(s *string) *MachineCreate { + if s != nil { + mc.SetAuthType(*s) + } + return mc +} + // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. func (mc *MachineCreate) AddAlertIDs(ids ...int) *MachineCreate { mc.mutation.AddAlertIDs(ids...) @@ -257,6 +271,10 @@ func (mc *MachineCreate) defaults() { v := machine.DefaultIsValidated mc.mutation.SetIsValidated(v) } + if _, ok := mc.mutation.AuthType(); !ok { + v := machine.DefaultAuthType + mc.mutation.SetAuthType(v) + } } // check runs all checks and user-defined validators on the builder. @@ -278,6 +296,9 @@ func (mc *MachineCreate) check() error { if _, ok := mc.mutation.IsValidated(); !ok { return &ValidationError{Name: "isValidated", err: errors.New(`ent: missing required field "Machine.isValidated"`)} } + if _, ok := mc.mutation.AuthType(); !ok { + return &ValidationError{Name: "auth_type", err: errors.New(`ent: missing required field "Machine.auth_type"`)} + } return nil } @@ -393,6 +414,14 @@ func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) { }) _node.Status = value } + if value, ok := mc.mutation.AuthType(); ok { + _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: machine.FieldAuthType, + }) + _node.AuthType = value + } if nodes := mc.mutation.AlertsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/pkg/database/ent/machine_update.go b/pkg/database/ent/machine_update.go index b396ef1d4..afdb89a8f 100644 --- a/pkg/database/ent/machine_update.go +++ b/pkg/database/ent/machine_update.go @@ -169,6 +169,20 @@ func (mu *MachineUpdate) ClearStatus() *MachineUpdate { return mu } +// SetAuthType sets the "auth_type" field. +func (mu *MachineUpdate) SetAuthType(s string) *MachineUpdate { + mu.mutation.SetAuthType(s) + return mu +} + +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (mu *MachineUpdate) SetNillableAuthType(s *string) *MachineUpdate { + if s != nil { + mu.SetAuthType(*s) + } + return mu +} + // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. func (mu *MachineUpdate) AddAlertIDs(ids ...int) *MachineUpdate { mu.mutation.AddAlertIDs(ids...) @@ -438,6 +452,13 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) { Column: machine.FieldStatus, }) } + if value, ok := mu.mutation.AuthType(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: machine.FieldAuthType, + }) + } if mu.mutation.AlertsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -651,6 +672,20 @@ func (muo *MachineUpdateOne) ClearStatus() *MachineUpdateOne { return muo } +// SetAuthType sets the "auth_type" field. +func (muo *MachineUpdateOne) SetAuthType(s string) *MachineUpdateOne { + muo.mutation.SetAuthType(s) + return muo +} + +// SetNillableAuthType sets the "auth_type" field if the given value is not nil. +func (muo *MachineUpdateOne) SetNillableAuthType(s *string) *MachineUpdateOne { + if s != nil { + muo.SetAuthType(*s) + } + return muo +} + // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. func (muo *MachineUpdateOne) AddAlertIDs(ids ...int) *MachineUpdateOne { muo.mutation.AddAlertIDs(ids...) @@ -944,6 +979,13 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e Column: machine.FieldStatus, }) } + if value, ok := muo.mutation.AuthType(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: machine.FieldAuthType, + }) + } if muo.mutation.AlertsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/pkg/database/ent/migrate/schema.go b/pkg/database/ent/migrate/schema.go index 891fdee89..5e0dc1f3c 100644 --- a/pkg/database/ent/migrate/schema.go +++ b/pkg/database/ent/migrate/schema.go @@ -69,6 +69,7 @@ var ( {Name: "version", Type: field.TypeString, Nullable: true}, {Name: "until", Type: field.TypeTime, Nullable: true}, {Name: "last_pull", Type: field.TypeTime}, + {Name: "auth_type", Type: field.TypeString, Default: "api-key"}, } // BouncersTable holds the schema information for the "bouncers" table. BouncersTable = &schema.Table{ @@ -163,6 +164,7 @@ var ( {Name: "version", Type: field.TypeString, Nullable: true}, {Name: "is_validated", Type: field.TypeBool, Default: false}, {Name: "status", Type: field.TypeString, Nullable: true}, + {Name: "auth_type", Type: field.TypeString, Default: "password"}, } // MachinesTable holds the schema information for the "machines" table. MachinesTable = &schema.Table{ diff --git a/pkg/database/ent/mutation.go b/pkg/database/ent/mutation.go index 7427e6b28..8d73b353e 100644 --- a/pkg/database/ent/mutation.go +++ b/pkg/database/ent/mutation.go @@ -2338,6 +2338,7 @@ type BouncerMutation struct { version *string until *time.Time last_pull *time.Time + auth_type *string clearedFields map[string]struct{} done bool oldValue func(context.Context) (*Bouncer, error) @@ -2880,6 +2881,42 @@ func (m *BouncerMutation) ResetLastPull() { m.last_pull = nil } +// SetAuthType sets the "auth_type" field. +func (m *BouncerMutation) SetAuthType(s string) { + m.auth_type = &s +} + +// AuthType returns the value of the "auth_type" field in the mutation. +func (m *BouncerMutation) AuthType() (r string, exists bool) { + v := m.auth_type + if v == nil { + return + } + return *v, true +} + +// OldAuthType returns the old "auth_type" field's value of the Bouncer entity. +// If the Bouncer object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *BouncerMutation) OldAuthType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAuthType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAuthType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAuthType: %w", err) + } + return oldValue.AuthType, nil +} + +// ResetAuthType resets all changes to the "auth_type" field. +func (m *BouncerMutation) ResetAuthType() { + m.auth_type = nil +} + // Where appends a list predicates to the BouncerMutation builder. func (m *BouncerMutation) Where(ps ...predicate.Bouncer) { m.predicates = append(m.predicates, ps...) @@ -2899,7 +2936,7 @@ func (m *BouncerMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *BouncerMutation) Fields() []string { - fields := make([]string, 0, 10) + fields := make([]string, 0, 11) if m.created_at != nil { fields = append(fields, bouncer.FieldCreatedAt) } @@ -2930,6 +2967,9 @@ func (m *BouncerMutation) Fields() []string { if m.last_pull != nil { fields = append(fields, bouncer.FieldLastPull) } + if m.auth_type != nil { + fields = append(fields, bouncer.FieldAuthType) + } return fields } @@ -2958,6 +2998,8 @@ func (m *BouncerMutation) Field(name string) (ent.Value, bool) { return m.Until() case bouncer.FieldLastPull: return m.LastPull() + case bouncer.FieldAuthType: + return m.AuthType() } return nil, false } @@ -2987,6 +3029,8 @@ func (m *BouncerMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldUntil(ctx) case bouncer.FieldLastPull: return m.OldLastPull(ctx) + case bouncer.FieldAuthType: + return m.OldAuthType(ctx) } return nil, fmt.Errorf("unknown Bouncer field %s", name) } @@ -3066,6 +3110,13 @@ func (m *BouncerMutation) SetField(name string, value ent.Value) error { } m.SetLastPull(v) return nil + case bouncer.FieldAuthType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAuthType(v) + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } @@ -3184,6 +3235,9 @@ func (m *BouncerMutation) ResetField(name string) error { case bouncer.FieldLastPull: m.ResetLastPull() return nil + case bouncer.FieldAuthType: + m.ResetAuthType() + return nil } return fmt.Errorf("unknown Bouncer field %s", name) } @@ -5246,6 +5300,7 @@ type MachineMutation struct { version *string isValidated *bool status *string + auth_type *string clearedFields map[string]struct{} alerts map[int]struct{} removedalerts map[int]struct{} @@ -5840,6 +5895,42 @@ func (m *MachineMutation) ResetStatus() { delete(m.clearedFields, machine.FieldStatus) } +// SetAuthType sets the "auth_type" field. +func (m *MachineMutation) SetAuthType(s string) { + m.auth_type = &s +} + +// AuthType returns the value of the "auth_type" field in the mutation. +func (m *MachineMutation) AuthType() (r string, exists bool) { + v := m.auth_type + if v == nil { + return + } + return *v, true +} + +// OldAuthType returns the old "auth_type" field's value of the Machine entity. +// If the Machine object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *MachineMutation) OldAuthType(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAuthType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAuthType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAuthType: %w", err) + } + return oldValue.AuthType, nil +} + +// ResetAuthType resets all changes to the "auth_type" field. +func (m *MachineMutation) ResetAuthType() { + m.auth_type = nil +} + // AddAlertIDs adds the "alerts" edge to the Alert entity by ids. func (m *MachineMutation) AddAlertIDs(ids ...int) { if m.alerts == nil { @@ -5913,7 +6004,7 @@ func (m *MachineMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *MachineMutation) Fields() []string { - fields := make([]string, 0, 11) + fields := make([]string, 0, 12) if m.created_at != nil { fields = append(fields, machine.FieldCreatedAt) } @@ -5947,6 +6038,9 @@ func (m *MachineMutation) Fields() []string { if m.status != nil { fields = append(fields, machine.FieldStatus) } + if m.auth_type != nil { + fields = append(fields, machine.FieldAuthType) + } return fields } @@ -5977,6 +6071,8 @@ func (m *MachineMutation) Field(name string) (ent.Value, bool) { return m.IsValidated() case machine.FieldStatus: return m.Status() + case machine.FieldAuthType: + return m.AuthType() } return nil, false } @@ -6008,6 +6104,8 @@ func (m *MachineMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldIsValidated(ctx) case machine.FieldStatus: return m.OldStatus(ctx) + case machine.FieldAuthType: + return m.OldAuthType(ctx) } return nil, fmt.Errorf("unknown Machine field %s", name) } @@ -6094,6 +6192,13 @@ func (m *MachineMutation) SetField(name string, value ent.Value) error { } m.SetStatus(v) return nil + case machine.FieldAuthType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAuthType(v) + return nil } return fmt.Errorf("unknown Machine field %s", name) } @@ -6221,6 +6326,9 @@ func (m *MachineMutation) ResetField(name string) error { case machine.FieldStatus: m.ResetStatus() return nil + case machine.FieldAuthType: + m.ResetAuthType() + return nil } return fmt.Errorf("unknown Machine field %s", name) } diff --git a/pkg/database/ent/runtime.go b/pkg/database/ent/runtime.go index c12c025ce..0bc4b1c45 100644 --- a/pkg/database/ent/runtime.go +++ b/pkg/database/ent/runtime.go @@ -82,6 +82,10 @@ func init() { bouncerDescLastPull := bouncerFields[9].Descriptor() // bouncer.DefaultLastPull holds the default value on creation for the last_pull field. bouncer.DefaultLastPull = bouncerDescLastPull.Default.(func() time.Time) + // bouncerDescAuthType is the schema descriptor for auth_type field. + bouncerDescAuthType := bouncerFields[10].Descriptor() + // bouncer.DefaultAuthType holds the default value on creation for the auth_type field. + bouncer.DefaultAuthType = bouncerDescAuthType.Default.(string) decisionFields := schema.Decision{}.Fields() _ = decisionFields // decisionDescCreatedAt is the schema descriptor for created_at field. @@ -152,6 +156,10 @@ func init() { machineDescIsValidated := machineFields[9].Descriptor() // machine.DefaultIsValidated holds the default value on creation for the isValidated field. machine.DefaultIsValidated = machineDescIsValidated.Default.(bool) + // machineDescAuthType is the schema descriptor for auth_type field. + machineDescAuthType := machineFields[11].Descriptor() + // machine.DefaultAuthType holds the default value on creation for the auth_type field. + machine.DefaultAuthType = machineDescAuthType.Default.(string) metaFields := schema.Meta{}.Fields() _ = metaFields // metaDescCreatedAt is the schema descriptor for created_at field. diff --git a/pkg/database/ent/schema/bouncer.go b/pkg/database/ent/schema/bouncer.go index 4efacaa8a..c30812912 100644 --- a/pkg/database/ent/schema/bouncer.go +++ b/pkg/database/ent/schema/bouncer.go @@ -29,6 +29,7 @@ func (Bouncer) Fields() []ent.Field { field.Time("until").Default(types.UtcNow).Optional().StructTag(`json:"until"`), field.Time("last_pull"). Default(types.UtcNow).StructTag(`json:"last_pull"`), + field.String("auth_type").StructTag(`json:"auth_type"`).Default(types.ApiKeyAuthType), } } diff --git a/pkg/database/ent/schema/machine.go b/pkg/database/ent/schema/machine.go index 4512bdf4a..f711bc612 100644 --- a/pkg/database/ent/schema/machine.go +++ b/pkg/database/ent/schema/machine.go @@ -35,6 +35,7 @@ func (Machine) Fields() []ent.Field { field.Bool("isValidated"). Default(false), field.String("status").Optional(), + field.String("auth_type").Default(types.PasswordAuthType).StructTag(`json:"auth_type"`), } } diff --git a/pkg/database/machines.go b/pkg/database/machines.go index ab4505cba..391c1103a 100644 --- a/pkg/database/machines.go +++ b/pkg/database/machines.go @@ -14,11 +14,11 @@ import ( const CapiMachineID = "CAPI" -func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool) (int, error) { +func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) { hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) if err != nil { c.Log.Warningf("CreateMachine : %s", err) - return 0, errors.Wrap(HashError, "") + return nil, errors.Wrap(HashError, "") } machineExist, err := c.Ent.Machine. @@ -26,34 +26,39 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA Where(machine.MachineIdEQ(*machineID)). Select(machine.FieldMachineId).Strings(c.CTX) if err != nil { - return 0, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) + return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) } if len(machineExist) > 0 { if force { _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX) if err != nil { c.Log.Warningf("CreateMachine : %s", err) - return 0, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) + return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) } - return 1, nil + machine, err := c.QueryMachineByID(*machineID) + if err != nil { + return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) + } + return machine, nil } - return 0, errors.Wrapf(UserExists, "user '%s'", *machineID) + return nil, errors.Wrapf(UserExists, "user '%s'", *machineID) } - _, err = c.Ent.Machine. + machine, err := c.Ent.Machine. Create(). SetMachineId(*machineID). SetPassword(string(hashPassword)). SetIpAddress(ipAddress). SetIsValidated(isValidated). + SetAuthType(authType). Save(c.CTX) if err != nil { c.Log.Warningf("CreateMachine : %s", err) - return 0, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID) + return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID) } - return 1, nil + return machine, nil } func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { diff --git a/pkg/types/constants.go b/pkg/types/constants.go new file mode 100644 index 000000000..3fe83de21 --- /dev/null +++ b/pkg/types/constants.go @@ -0,0 +1,5 @@ +package types + +const ApiKeyAuthType = "api-key" +const TlsAuthType = "tls" +const PasswordAuthType = "password" diff --git a/tests/bats/06_crowdsec.bats b/tests/bats/06_crowdsec.bats index 626a078b7..5446c9042 100755 --- a/tests/bats/06_crowdsec.bats +++ b/tests/bats/06_crowdsec.bats @@ -37,5 +37,5 @@ declare stderr @test "${FILE} CS_LAPI_SECRET not strong enough" { CS_LAPI_SECRET=foo run -1 --separate-stderr timeout 2s "${CROWDSEC}" run -0 echo "${stderr}" - assert_output --partial "api server init: unable to run local API: CS_LAPI_SECRET not strong enough" + assert_output --partial "api server init: unable to run local API: controller init: CS_LAPI_SECRET not strong enough" } diff --git a/tests/bats/11_bouncers_tls.bats b/tests/bats/11_bouncers_tls.bats new file mode 100644 index 000000000..8969e53c7 --- /dev/null +++ b/tests/bats/11_bouncers_tls.bats @@ -0,0 +1,97 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +config_disable_agent() { + yq 'del(.crowdsec_service)' -i "${CONFIG_YAML}" +} + +setup_file() { + load "../lib/setup_file.sh" + ./instance-data load + tmpdir=$(mktemp -d) + export tmpdir + #gen the CA + cfssl gencert --initca ./cfssl/ca.json 2>/dev/null | cfssljson --bare "${tmpdir}/ca" + #gen an intermediate + cfssl gencert --initca ./cfssl/intermediate.json 2>/dev/null | cfssljson --bare "${tmpdir}/inter" + cfssl sign -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config ./cfssl/profiles.json -profile intermediate_ca "${tmpdir}/inter.csr" 2>/dev/null | cfssljson --bare "${tmpdir}/inter" + #gen server cert for crowdsec with the intermediate + cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=server ./cfssl/server.json 2>/dev/null | cfssljson --bare "${tmpdir}/server" + #gen client cert for the bouncer + cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/bouncer.json 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer" + #gen client cert for the bouncer with an invalid OU + cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/bouncer_invalid.json 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_bad_ou" + #gen client cert for the bouncer directly signed by the CA, it should be refused by crowdsec as uses the intermediate + cfssl gencert -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/bouncer.json 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_invalid" + + cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/bouncer.json 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_revoked" + serial="$(openssl x509 -noout -serial -in ${tmpdir}/bouncer_revoked.pem | cut -d '=' -f2)" + echo "ibase=16; $serial" | bc > "${tmpdir}/serials.txt" + cfssl gencrl "${tmpdir}/serials.txt" "${tmpdir}/ca.pem" "${tmpdir}/ca-key.pem" | base64 -d | openssl crl -inform DER -out "${tmpdir}/crl.pem" + + + yq ' + .api.server.tls.cert_file=strenv(tmpdir) + "/server.pem" | + .api.server.tls.key_file=strenv(tmpdir) + "/server-key.pem" | + .api.server.tls.ca_cert_path=strenv(tmpdir) + "/inter.pem" | + .api.server.tls.crl_path=strenv(tmpdir) + "/crl.pem" | + .api.server.tls.bouncers_allowed_ou=["bouncer-ou"] + ' -i "${CONFIG_YAML}" + + config_disable_agent +} + + +teardown_file() { + load "../lib/teardown_file.sh" + rm -rf $tmpdir +} + +setup() { + load "../lib/setup.sh" + ./instance-crowdsec start +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "$FILE there are 0 bouncers" { + run -0 cscli bouncers list -o json + assert_output "[]" +} + +@test "$FILE simulate one bouncer request with a valid cert" { + run -0 curl -s --cert "${tmpdir}/bouncer.pem" --key "${tmpdir}/bouncer-key.pem" --cacert "${tmpdir}/inter.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 + assert_output "null" + run -0 cscli bouncers list -o json + run -0 jq '. | length' <(output) + assert_output '1' + run -0 cscli bouncers list -o json + run -0 jq -r '.[] | .name' <(output) + assert_output "localhost@127.0.0.1" + run cscli bouncers delete localhost@127.0.0.1 +} + +@test "$FILE simulate one bouncer request with an invalid cert" { + run curl -s --cert "${tmpdir}/bouncer_invalid.pem" --key "${tmpdir}/bouncer_invalid-key.pem" --cacert "${tmpdir}/ca-key.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 + run -0 cscli bouncers list -o json + assert_output "[]" +} + +@test "$FILE simulate one bouncer request with an invalid OU" { + run curl -s --cert "${tmpdir}/bouncer_bad_ou.pem" --key "${tmpdir}/bouncer_bad_ou-key.pem" --cacert "${tmpdir}/inter.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 + run -0 cscli bouncers list -o json + assert_output "[]" +} + +@test "$FILE simulate one bouncer request with a revoked certificate" { + run -0 curl -i -s --cert "${tmpdir}/bouncer_revoked.pem" --key "${tmpdir}/bouncer_revoked-key.pem" --cacert "${tmpdir}/inter.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42 + assert_output --partial "access forbidden" + run -0 cscli bouncers list -o json + assert_output "[]" +} \ No newline at end of file diff --git a/tests/bats/30_machines_tls.bats b/tests/bats/30_machines_tls.bats new file mode 100644 index 000000000..7ff4ec140 --- /dev/null +++ b/tests/bats/30_machines_tls.bats @@ -0,0 +1,136 @@ +#!/usr/bin/env bats +# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si: + +set -u + +setup_file() { + load "../lib/setup_file.sh" + ./instance-data load + tmpdir=$(mktemp -d) + export tmpdir + #gen the CA + cfssl gencert --initca ./cfssl/ca.json 2>/dev/null | cfssljson --bare "${tmpdir}/ca" + #gen an intermediate + cfssl gencert --initca ./cfssl/intermediate.json 2>/dev/null | cfssljson --bare "${tmpdir}/inter" + cfssl sign -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config ./cfssl/profiles.json -profile intermediate_ca "${tmpdir}/inter.csr" 2>/dev/null | cfssljson --bare "${tmpdir}/inter" + #gen server cert for crowdsec with the intermediate + cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=server ./cfssl/server.json 2>/dev/null | cfssljson --bare "${tmpdir}/server" + #gen client cert for the agent + cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/agent.json 2>/dev/null | cfssljson --bare "${tmpdir}/agent" + #gen client cert for the agent with an invalid OU + cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/agent_invalid.json 2>/dev/null | cfssljson --bare "${tmpdir}/agent_bad_ou" + #gen client cert for the agent directly signed by the CA, it should be refused by crowdsec as uses the intermediate + cfssl gencert -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/agent.json 2>/dev/null | cfssljson --bare "${tmpdir}/agent_invalid" + + cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/agent.json 2>/dev/null | cfssljson --bare "${tmpdir}/agent_revoked" + serial="$(openssl x509 -noout -serial -in ${tmpdir}/agent_revoked.pem | cut -d '=' -f2)" + echo "ibase=16; $serial" | bc > "${tmpdir}/serials.txt" + cfssl gencrl "${tmpdir}/serials.txt" "${tmpdir}/ca.pem" "${tmpdir}/ca-key.pem" | base64 -d | openssl crl -inform DER -out "${tmpdir}/crl.pem" + + + yq ' + .api.server.tls.cert_file=strenv(tmpdir) + "/server.pem" | + .api.server.tls.key_file=strenv(tmpdir) + "/server-key.pem" | + .api.server.tls.ca_cert_path=strenv(tmpdir) + "/inter.pem" | + .api.server.tls.crl_path=strenv(tmpdir) + "/crl.pem" | + .api.server.tls.agents_allowed_ou=["agent-ou"] + ' -i "${CONFIG_YAML}" + +} + +teardown_file() { + load "../lib/teardown_file.sh" +} + +setup() { + load "../lib/setup.sh" + cscli machines delete githubciXXXXXXXXXXXXXXXXXXXXXXXX +} + +teardown() { + ./instance-crowdsec stop +} + +#---------- + +@test "$FILE invalid OU for agent" { + CONFIG_DIR=$(dirname ${CONFIG_YAML}) + + yq ' + .ca_cert_path=strenv(tmpdir) + "/inter.pem" | + .key_path=strenv(tmpdir) + "/agent_bad_ou-key.pem" | + .cert_path=strenv(tmpdir) + "/agent_bad_ou.pem" | + .url="https://127.0.0.1:8080" + ' -i "${CONFIG_DIR}/local_api_credentials.yaml" + + yq 'del(.login)' -i "${CONFIG_DIR}/local_api_credentials.yaml" + yq 'del(.password)' -i "${CONFIG_DIR}/local_api_credentials.yaml" + ./instance-crowdsec start + #let the agent start + sleep 2 + run -0 cscli machines list -o json + assert_output '[]' +} + +@test "$FILE we have exactly one machine registered with TLS" { + CONFIG_DIR=$(dirname ${CONFIG_YAML}) + + yq ' + .ca_cert_path=strenv(tmpdir) + "/inter.pem" | + .key_path=strenv(tmpdir) + "/agent-key.pem" | + .cert_path=strenv(tmpdir) + "/agent.pem" | + .url="https://127.0.0.1:8080" + ' -i "${CONFIG_DIR}/local_api_credentials.yaml" + + yq 'del(.login)' -i "${CONFIG_DIR}/local_api_credentials.yaml" + yq 'del(.password)' -i "${CONFIG_DIR}/local_api_credentials.yaml" + ./instance-crowdsec start + #let the agent start + sleep 2 + run -0 cscli machines list -o json + run -0 jq -c '[. | length, .[0].machineId[0:32], .[0].isValidated, .[0].ipAddress, .[0].auth_type]' <(output) + + assert_output '[1,"localhost@127.0.0.1",true,"127.0.0.1","tls"]' + cscli machines delete localhost@127.0.0.1 + + ./instance-crowdsec stop +} + + +@test "$FILE invalid cert for agent" { + CONFIG_DIR=$(dirname ${CONFIG_YAML}) + + yq ' + .ca_cert_path=strenv(tmpdir) + "/inter.pem" | + .key_path=strenv(tmpdir) + "/agent_invalid-key.pem" | + .cert_path=strenv(tmpdir) + "/agent_invalid.pem" | + .url="https://127.0.0.1:8080" + ' -i "${CONFIG_DIR}/local_api_credentials.yaml" + + yq 'del(.login)' -i "${CONFIG_DIR}/local_api_credentials.yaml" + yq 'del(.password)' -i "${CONFIG_DIR}/local_api_credentials.yaml" + ./instance-crowdsec start + #let the agent start + sleep 2 + run -0 cscli machines list -o json + assert_output '[]' +} + +@test "$FILE revoked cert for agent" { + CONFIG_DIR=$(dirname ${CONFIG_YAML}) + + yq ' + .ca_cert_path=strenv(tmpdir) + "/inter.pem" | + .key_path=strenv(tmpdir) + "/agent_revoked-key.pem" | + .cert_path=strenv(tmpdir) + "/agent_revoked.pem" | + .url="https://127.0.0.1:8080" + ' -i "${CONFIG_DIR}/local_api_credentials.yaml" + + yq 'del(.login)' -i "${CONFIG_DIR}/local_api_credentials.yaml" + yq 'del(.password)' -i "${CONFIG_DIR}/local_api_credentials.yaml" + ./instance-crowdsec start + #let the agent start + sleep 2 + run -0 cscli machines list -o json + assert_output '[]' +} \ No newline at end of file diff --git a/tests/cfssl/agent.json b/tests/cfssl/agent.json new file mode 100644 index 000000000..693e3aa51 --- /dev/null +++ b/tests/cfssl/agent.json @@ -0,0 +1,16 @@ +{ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ + { + "C": "FR", + "L": "Paris", + "O": "Crowdsec", + "OU": "agent-ou", + "ST": "France" + } + ] + } \ No newline at end of file diff --git a/tests/cfssl/agent_invalid.json b/tests/cfssl/agent_invalid.json new file mode 100644 index 000000000..c61d4dee6 --- /dev/null +++ b/tests/cfssl/agent_invalid.json @@ -0,0 +1,16 @@ +{ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ + { + "C": "FR", + "L": "Paris", + "O": "Crowdsec", + "OU": "this-is-not-the-ou-youre-looking-for", + "ST": "France" + } + ] + } \ No newline at end of file diff --git a/tests/cfssl/bouncer.json b/tests/cfssl/bouncer.json new file mode 100644 index 000000000..9a07f5766 --- /dev/null +++ b/tests/cfssl/bouncer.json @@ -0,0 +1,16 @@ +{ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ + { + "C": "FR", + "L": "Paris", + "O": "Crowdsec", + "OU": "bouncer-ou", + "ST": "France" + } + ] + } \ No newline at end of file diff --git a/tests/cfssl/bouncer_invalid.json b/tests/cfssl/bouncer_invalid.json new file mode 100644 index 000000000..c61d4dee6 --- /dev/null +++ b/tests/cfssl/bouncer_invalid.json @@ -0,0 +1,16 @@ +{ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ + { + "C": "FR", + "L": "Paris", + "O": "Crowdsec", + "OU": "this-is-not-the-ou-youre-looking-for", + "ST": "France" + } + ] + } \ No newline at end of file diff --git a/tests/cfssl/ca.json b/tests/cfssl/ca.json new file mode 100644 index 000000000..ed907e037 --- /dev/null +++ b/tests/cfssl/ca.json @@ -0,0 +1,16 @@ +{ + "CN": "CrowdSec Test CA", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ + { + "C": "FR", + "L": "Paris", + "O": "Crowdsec", + "OU": "Crowdsec", + "ST": "France" + } + ] +} \ No newline at end of file diff --git a/tests/cfssl/intermediate.json b/tests/cfssl/intermediate.json new file mode 100644 index 000000000..3996ce6e1 --- /dev/null +++ b/tests/cfssl/intermediate.json @@ -0,0 +1,19 @@ +{ + "CN": "CrowdSec Test CA Intermediate", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ + { + "C": "FR", + "L": "Paris", + "O": "Crowdsec", + "OU": "Crowdsec Intermediate", + "ST": "France" + } + ], + "ca": { + "expiry": "42720h" + } + } \ No newline at end of file diff --git a/tests/cfssl/profiles.json b/tests/cfssl/profiles.json new file mode 100644 index 000000000..d0dfced4a --- /dev/null +++ b/tests/cfssl/profiles.json @@ -0,0 +1,44 @@ +{ + "signing": { + "default": { + "expiry": "8760h" + }, + "profiles": { + "intermediate_ca": { + "usages": [ + "signing", + "digital signature", + "key encipherment", + "cert sign", + "crl sign", + "server auth", + "client auth" + ], + "expiry": "8760h", + "ca_constraint": { + "is_ca": true, + "max_path_len": 0, + "max_path_len_zero": true + } + }, + "server": { + "usages": [ + "signing", + "digital signing", + "key encipherment", + "server auth" + ], + "expiry": "8760h" + }, + "client": { + "usages": [ + "signing", + "digital signature", + "key encipherment", + "client auth" + ], + "expiry": "8760h" + } + } + } + } \ No newline at end of file diff --git a/tests/cfssl/server.json b/tests/cfssl/server.json new file mode 100644 index 000000000..37018259e --- /dev/null +++ b/tests/cfssl/server.json @@ -0,0 +1,20 @@ +{ + "CN": "localhost", + "key": { + "algo": "rsa", + "size": 2048 + }, + "names": [ + { + "C": "FR", + "L": "Paris", + "O": "Crowdsec", + "OU": "Crowdsec Server", + "ST": "France" + } + ], + "hosts": [ + "127.0.0.1", + "localhost" + ] + } \ No newline at end of file diff --git a/tests/check-requirements b/tests/check-requirements index 2a820bc83..744bf7532 100755 --- a/tests/check-requirements +++ b/tests/check-requirements @@ -68,6 +68,22 @@ check_daemonizer() { esac } +check_cfssl() { + # shellcheck disable=SC2016 + howto_install='You can install it with "go get -u github.com/cloudflare/cfssl/cmd/cfssl" and add ~/go/bin to $PATH.' + if ! command -v cfssl >/dev/null; then + die "Missing required program 'cfssl'. $howto_install" + fi +} + +check_cfssljson() { + # shellcheck disable=SC2016 + howto_install='You can install it with "go get -u github.com/cloudflare/cfssl/cmd/cfssljson" and add ~/go/bin to $PATH.' + if ! command -v cfssljson >/dev/null; then + die "Missing required program 'cfssljson'. $howto_install" + fi +} + check_gocovmerge() { if ! command -v gocovmerge >/dev/null; then die "missing required program 'gocovmerge'. You can install it with \"go install github.com/wadey/gocovmerge@latest\"" @@ -76,6 +92,8 @@ check_gocovmerge() { check_bats_core check_daemonizer +check_cfssl +check_cfssljson check_jq check_nc check_python3