Add support for certificate authentication for agents and bouncers (#1428)

This commit is contained in:
Thibault "bui" Koechlin 2022-06-08 16:05:52 +02:00 committed by GitHub
parent bdda8691ff
commit 1c0fe09576
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
55 changed files with 1985 additions and 218 deletions

View file

@ -35,8 +35,12 @@ jobs:
- name: "Install bats dependencies" - name: "Install bats dependencies"
run: | run: |
sudo apt install -y -qq build-essential daemonize jq netcat-openbsd sudo apt install -y -qq build-essential daemonize jq netcat-openbsd
GO111MODULE=on go get github.com/mikefarah/yq/v4 go install github.com/mikefarah/yq/v4@latest
go install github.com/cloudflare/cfssl/cmd/cfssl@latest
go install github.com/cloudflare/cfssl/cmd/cfssljson@latest
sudo cp -u ~/go/bin/yq /usr/local/bin/ sudo cp -u ~/go/bin/yq /usr/local/bin/
sudo cp -u ~/go/bin/cfssl /usr/local/bin
sudo cp -u ~/go/bin/cfssljson /usr/local/bin
- name: "Build crowdsec and fixture" - name: "Build crowdsec and fixture"
run: make bats-clean bats-build bats-fixture run: make bats-clean bats-build bats-fixture

View file

@ -46,8 +46,12 @@ jobs:
- name: "Install bats dependencies" - name: "Install bats dependencies"
run: | run: |
sudo apt install -y -qq build-essential daemonize jq netcat-openbsd sudo apt install -y -qq build-essential daemonize jq netcat-openbsd
GO111MODULE=on go get github.com/mikefarah/yq/v4 go install github.com/mikefarah/yq/v4@latest
go install github.com/cloudflare/cfssl/cmd/cfssl@latest
go install github.com/cloudflare/cfssl/cmd/cfssljson@latest
sudo cp -u ~/go/bin/yq /usr/local/bin/ sudo cp -u ~/go/bin/yq /usr/local/bin/
sudo cp -u ~/go/bin/cfssl /usr/local/bin
sudo cp -u ~/go/bin/cfssljson /usr/local/bin
- name: "Build crowdsec and fixture" - name: "Build crowdsec and fixture"
run: make bats-clean bats-build bats-fixture run: make bats-clean bats-build bats-fixture

View file

@ -47,8 +47,12 @@ jobs:
- name: "Install bats dependencies" - name: "Install bats dependencies"
run: | run: |
sudo apt install -y -qq build-essential daemonize jq netcat-openbsd sudo apt install -y -qq build-essential daemonize jq netcat-openbsd
GO111MODULE=on go get github.com/mikefarah/yq/v4 go install github.com/mikefarah/yq/v4@latest
go install github.com/cloudflare/cfssl/cmd/cfssl@latest
go install github.com/cloudflare/cfssl/cmd/cfssljson@latest
sudo cp -u ~/go/bin/yq /usr/local/bin/ sudo cp -u ~/go/bin/yq /usr/local/bin/
sudo cp -u ~/go/bin/cfssl /usr/local/bin
sudo cp -u ~/go/bin/cfssljson /usr/local/bin
- name: "Build crowdsec and fixture (DB_BACKEND: pgx)" - name: "Build crowdsec and fixture (DB_BACKEND: pgx)"
run: make clean bats-build bats-fixture run: make clean bats-build bats-fixture

View file

@ -32,10 +32,11 @@ jobs:
- name: "Install bats dependencies" - name: "Install bats dependencies"
run: | run: |
sudo apt install -y -qq build-essential daemonize jq netcat-openbsd sudo apt install -y -qq build-essential daemonize jq netcat-openbsd
GO111MODULE=on go get github.com/mikefarah/yq/v4 go install github.com/mikefarah/yq/v4@latest
sudo cp -u ~/go/bin/yq /usr/local/bin/ go install github.com/cloudflare/cfssl/cmd/cfssl@latest
go install github.com/cloudflare/cfssl/cmd/cfssljson@latest
go install github.com/wadey/gocovmerge@latest go install github.com/wadey/gocovmerge@latest
sudo cp -u ~/go/bin/gocovmerge /usr/local/bin/ sudo cp -u ~/go/bin/yq ~/go/bin/gocovmerge ~/go/bin/cfssl ~/go/bin/cfssljson /usr/local/bin/
- name: "Build crowdsec and fixture" - name: "Build crowdsec and fixture"
run: TEST_COVERAGE=true make bats-clean bats-build bats-fixture run: TEST_COVERAGE=true make bats-clean bats-build bats-fixture

View file

@ -9,6 +9,7 @@ import (
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
"github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/enescakir/emoji" "github.com/enescakir/emoji"
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -65,7 +66,7 @@ Note: This command requires database direct access, so is intended to be run on
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
table.SetAlignment(tablewriter.ALIGN_LEFT) table.SetAlignment(tablewriter.ALIGN_LEFT)
table.SetHeader([]string{"Name", "IP Address", "Valid", "Last API pull", "Type", "Version"}) table.SetHeader([]string{"Name", "IP Address", "Valid", "Last API pull", "Type", "Version", "Auth Type"})
for _, b := range blockers { for _, b := range blockers {
var revoked string var revoked string
if !b.Revoked { if !b.Revoked {
@ -73,7 +74,7 @@ Note: This command requires database direct access, so is intended to be run on
} else { } else {
revoked = emoji.Prohibited.String() revoked = emoji.Prohibited.String()
} }
table.Append([]string{b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version}) table.Append([]string{b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version, b.AuthType})
} }
table.Render() table.Render()
} else if csConfig.Cscli.Output == "json" { } else if csConfig.Cscli.Output == "json" {
@ -84,7 +85,7 @@ Note: This command requires database direct access, so is intended to be run on
fmt.Printf("%s", string(x)) fmt.Printf("%s", string(x))
} else if csConfig.Cscli.Output == "raw" { } else if csConfig.Cscli.Output == "raw" {
csvwriter := csv.NewWriter(os.Stdout) csvwriter := csv.NewWriter(os.Stdout)
err := csvwriter.Write([]string{"name", "ip", "revoked", "last_pull", "type", "version"}) err := csvwriter.Write([]string{"name", "ip", "revoked", "last_pull", "type", "version", "auth_type"})
if err != nil { if err != nil {
log.Fatalf("failed to write raw header: %s", err) log.Fatalf("failed to write raw header: %s", err)
} }
@ -95,7 +96,7 @@ Note: This command requires database direct access, so is intended to be run on
} else { } else {
revoked = "pending" revoked = "pending"
} }
err := csvwriter.Write([]string{b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version}) err := csvwriter.Write([]string{b.Name, b.IPAddress, revoked, b.LastPull.Format(time.RFC3339), b.Type, b.Version, b.AuthType})
if err != nil { if err != nil {
log.Fatalf("failed to write raw: %s", err) log.Fatalf("failed to write raw: %s", err)
} }
@ -129,7 +130,7 @@ cscli bouncers add MyBouncerName -k %s`, generatePassword(32)),
if err != nil { if err != nil {
log.Fatalf("unable to generate api key: %s", err) log.Fatalf("unable to generate api key: %s", err)
} }
err = dbClient.CreateBouncer(keyName, keyIP, middlewares.HashSHA512(apiKey)) _, err = dbClient.CreateBouncer(keyName, keyIP, middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
if err != nil { if err != nil {
log.Fatalf("unable to create bouncer: %s", err) log.Fatalf("unable to create bouncer: %s", err)
} }

View file

@ -368,6 +368,29 @@ func NewConfigCmd() *cobra.Command {
if csConfig.API.Server.TLS.KeyFilePath != "" { if csConfig.API.Server.TLS.KeyFilePath != "" {
fmt.Printf(" - Key File : %s\n", csConfig.API.Server.TLS.KeyFilePath) fmt.Printf(" - Key File : %s\n", csConfig.API.Server.TLS.KeyFilePath)
} }
if csConfig.API.Server.TLS.CACertPath != "" {
fmt.Printf(" - CA Cert : %s\n", csConfig.API.Server.TLS.CACertPath)
}
if csConfig.API.Server.TLS.CRLPath != "" {
fmt.Printf(" - CRL : %s\n", csConfig.API.Server.TLS.CRLPath)
}
if csConfig.API.Server.TLS.CacheExpiration != nil {
fmt.Printf(" - Cache Expiration : %s\n", csConfig.API.Server.TLS.CacheExpiration)
}
if csConfig.API.Server.TLS.ClientVerification != "" {
fmt.Printf(" - Client Verification : %s\n", csConfig.API.Server.TLS.ClientVerification)
}
if csConfig.API.Server.TLS.AllowedAgentsOU != nil {
for _, ou := range csConfig.API.Server.TLS.AllowedAgentsOU {
fmt.Printf(" - Allowed Agents OU : %s\n", ou)
}
}
if csConfig.API.Server.TLS.AllowedBouncersOU != nil {
for _, ou := range csConfig.API.Server.TLS.AllowedBouncersOU {
fmt.Printf(" - Allowed Bouncers OU : %s\n", ou)
}
}
} }
fmt.Printf(" - Trusted IPs: \n") fmt.Printf(" - Trusted IPs: \n")
for _, ip := range csConfig.API.Server.TrustedIPs { for _, ip := range csConfig.API.Server.TrustedIPs {

View file

@ -14,6 +14,7 @@ import (
"github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2"
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/crowdsecurity/machineid" "github.com/crowdsecurity/machineid"
"github.com/enescakir/emoji" "github.com/enescakir/emoji"
"github.com/go-openapi/strfmt" "github.com/go-openapi/strfmt"
@ -140,7 +141,7 @@ Note: This command requires database direct access, so is intended to be run on
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
table.SetAlignment(tablewriter.ALIGN_LEFT) table.SetAlignment(tablewriter.ALIGN_LEFT)
table.SetHeader([]string{"Name", "IP Address", "Last Update", "Status", "Version", "Last Heartbeat"}) table.SetHeader([]string{"Name", "IP Address", "Last Update", "Status", "Version", "Auth Type", "Last Heartbeat"})
for _, w := range machines { for _, w := range machines {
var validated string var validated string
if w.IsValidated { if w.IsValidated {
@ -153,7 +154,7 @@ Note: This command requires database direct access, so is intended to be run on
if lastHeartBeat > 2*time.Minute { if lastHeartBeat > 2*time.Minute {
hbDisplay = fmt.Sprintf("%s %s", emoji.Warning.String(), lastHeartBeat.Truncate(time.Second).String()) hbDisplay = fmt.Sprintf("%s %s", emoji.Warning.String(), lastHeartBeat.Truncate(time.Second).String())
} }
table.Append([]string{w.MachineId, w.IpAddress, w.UpdatedAt.Format(time.RFC3339), validated, w.Version, hbDisplay}) table.Append([]string{w.MachineId, w.IpAddress, w.UpdatedAt.Format(time.RFC3339), validated, w.Version, w.AuthType, hbDisplay})
} }
table.Render() table.Render()
} else if csConfig.Cscli.Output == "json" { } else if csConfig.Cscli.Output == "json" {
@ -164,7 +165,7 @@ Note: This command requires database direct access, so is intended to be run on
fmt.Printf("%s", string(x)) fmt.Printf("%s", string(x))
} else if csConfig.Cscli.Output == "raw" { } else if csConfig.Cscli.Output == "raw" {
csvwriter := csv.NewWriter(os.Stdout) csvwriter := csv.NewWriter(os.Stdout)
err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "last_heartbeat"}) err := csvwriter.Write([]string{"machine_id", "ip_address", "updated_at", "validated", "version", "auth_type", "last_heartbeat"})
if err != nil { if err != nil {
log.Fatalf("failed to write header: %s", err) log.Fatalf("failed to write header: %s", err)
} }
@ -175,7 +176,7 @@ Note: This command requires database direct access, so is intended to be run on
} else { } else {
validated = "false" validated = "false"
} }
err := csvwriter.Write([]string{w.MachineId, w.IpAddress, w.UpdatedAt.Format(time.RFC3339), validated, w.Version, time.Now().UTC().Sub(*w.LastHeartbeat).Truncate(time.Second).String()}) err := csvwriter.Write([]string{w.MachineId, w.IpAddress, w.UpdatedAt.Format(time.RFC3339), validated, w.Version, w.AuthType, time.Now().UTC().Sub(*w.LastHeartbeat).Truncate(time.Second).String()})
if err != nil { if err != nil {
log.Fatalf("failed to write raw output : %s", err) log.Fatalf("failed to write raw output : %s", err)
} }
@ -244,7 +245,7 @@ cscli machines add MyTestMachine --password MyPassword
survey.AskOne(qs, &machinePassword) survey.AskOne(qs, &machinePassword)
} }
password := strfmt.Password(machinePassword) password := strfmt.Password(machinePassword)
_, err = dbClient.CreateMachine(&machineID, &password, "", true, forceAdd) _, err = dbClient.CreateMachine(&machineID, &password, "", true, forceAdd, types.PasswordAuthType)
if err != nil { if err != nil {
log.Fatalf("unable to create machine: %s", err) log.Fatalf("unable to create machine: %s", err)
} }

View file

@ -3,6 +3,7 @@ package apiclient
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -15,6 +16,8 @@ import (
var ( var (
InsecureSkipVerify = false InsecureSkipVerify = false
Cert *tls.Certificate
CaCertPool *x509.CertPool
) )
type ApiClient struct { type ApiClient struct {
@ -49,7 +52,12 @@ func NewClient(config *Config) (*ApiClient, error) {
VersionPrefix: config.VersionPrefix, VersionPrefix: config.VersionPrefix,
UpdateScenario: config.UpdateScenario, UpdateScenario: config.UpdateScenario,
} }
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
if Cert != nil {
tlsconfig.RootCAs = CaCertPool
tlsconfig.Certificates = []tls.Certificate{*Cert}
}
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig
c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix} c := &ApiClient{client: t.Client(), BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix}
c.common.client = c c.common.client = c
c.Decisions = (*DecisionsService)(&c.common) c.Decisions = (*DecisionsService)(&c.common)
@ -66,7 +74,12 @@ func NewDefaultClient(URL *url.URL, prefix string, userAgent string, client *htt
if client == nil { if client == nil {
client = &http.Client{} client = &http.Client{}
if ht, ok := http.DefaultTransport.(*http.Transport); ok { if ht, ok := http.DefaultTransport.(*http.Transport); ok {
ht.TLSClientConfig = &tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
if Cert != nil {
tlsconfig.RootCAs = CaCertPool
tlsconfig.Certificates = []tls.Certificate{*Cert}
}
ht.TLSClientConfig = &tlsconfig
client.Transport = ht client.Transport = ht
} }
} }
@ -86,7 +99,12 @@ func RegisterClient(config *Config, client *http.Client) (*ApiClient, error) {
if client == nil { if client == nil {
client = &http.Client{} client = &http.Client{}
} }
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: InsecureSkipVerify} tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
if Cert != nil {
tlsconfig.RootCAs = CaCertPool
tlsconfig.Certificates = []tls.Certificate{*Cert}
}
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tlsconfig
c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix} c := &ApiClient{client: client, BaseURL: config.URL, UserAgent: config.UserAgent, URLPrefix: config.VersionPrefix}
c.common.client = c c.common.client = c
c.Decisions = (*DecisionsService)(&c.common) c.Decisions = (*DecisionsService)(&c.common)

View file

@ -45,17 +45,22 @@ func SetupLAPITest(t *testing.T) LAPI {
func (l *LAPI) InsertAlertFromFile(path string) *httptest.ResponseRecorder { func (l *LAPI) InsertAlertFromFile(path string) *httptest.ResponseRecorder {
alertReader := GetAlertReaderFromFile(path) alertReader := GetAlertReaderFromFile(path)
return l.RecordResponse("POST", "/v1/alerts", alertReader) return l.RecordResponse("POST", "/v1/alerts", alertReader, "password")
} }
func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader) *httptest.ResponseRecorder { func (l *LAPI) RecordResponse(verb string, url string, body *strings.Reader, authType string) *httptest.ResponseRecorder {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, err := http.NewRequest(verb, url, body) req, err := http.NewRequest(verb, url, body)
if err != nil { if err != nil {
l.t.Fatal(err) l.t.Fatal(err)
} }
req.Header.Add("X-Api-Key", l.bouncerKey) if authType == "apikey" {
AddAuthHeaders(req, l.loginResp) req.Header.Add("X-Api-Key", l.bouncerKey)
} else if authType == "password" {
AddAuthHeaders(req, l.loginResp)
} else {
l.t.Fatal("auth type not supported")
}
l.router.ServeHTTP(w, req) l.router.ServeHTTP(w, req)
return w return w
} }
@ -93,6 +98,7 @@ func LoginToTestAPI(router *gin.Engine, config csconfig.Config) (models.WatcherA
if err != nil { if err != nil {
return models.WatcherAuthResponse{}, fmt.Errorf("%s", err.Error()) return models.WatcherAuthResponse{}, fmt.Errorf("%s", err.Error())
} }
return loginResp, nil return loginResp, nil
} }
@ -107,13 +113,13 @@ func TestSimulatedAlert(t *testing.T) {
alertContent := GetAlertReaderFromFile("./tests/alert_minibulk+simul.json") alertContent := GetAlertReaderFromFile("./tests/alert_minibulk+simul.json")
//exclude decision in simulation mode //exclude decision in simulation mode
w := lapi.RecordResponse("GET", "/v1/alerts?simulated=false", alertContent) w := lapi.RecordResponse("GET", "/v1/alerts?simulated=false", alertContent, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `)
assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) assert.NotContains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `)
//include decision in simulation mode //include decision in simulation mode
w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", alertContent) w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", alertContent, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.178 performed crowdsecurity/ssh-bf (6 events over `)
assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `) assert.Contains(t, w.Body.String(), `"message":"Ip 91.121.79.179 performed crowdsecurity/ssh-bf (6 events over `)
@ -123,14 +129,14 @@ func TestCreateAlert(t *testing.T) {
lapi := SetupLAPITest(t) lapi := SetupLAPITest(t)
// Create Alert with invalid format // Create Alert with invalid format
w := lapi.RecordResponse("POST", "/v1/alerts", strings.NewReader("test")) w := lapi.RecordResponse("POST", "/v1/alerts", strings.NewReader("test"), "password")
assert.Equal(t, 400, w.Code) assert.Equal(t, 400, w.Code)
assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String()) assert.Equal(t, "{\"message\":\"invalid character 'e' in literal true (expecting 'r')\"}", w.Body.String())
// Create Alert with invalid input // Create Alert with invalid input
alertContent := GetAlertReaderFromFile("./tests/invalidAlert_sample.json") alertContent := GetAlertReaderFromFile("./tests/invalidAlert_sample.json")
w = lapi.RecordResponse("POST", "/v1/alerts", alertContent) w = lapi.RecordResponse("POST", "/v1/alerts", alertContent, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"validation failure list:\\n0.scenario in body is required\\n0.scenario_hash in body is required\\n0.scenario_version in body is required\\n0.simulated in body is required\\n0.source in body is required\"}", w.Body.String()) assert.Equal(t, "{\"message\":\"validation failure list:\\n0.scenario in body is required\\n0.scenario_hash in body is required\\n0.scenario_version in body is required\\n0.simulated in body is required\\n0.source in body is required\"}", w.Body.String())
@ -177,13 +183,13 @@ func TestAlertListFilters(t *testing.T) {
//bad filter //bad filter
w := lapi.RecordResponse("GET", "/v1/alerts?test=test", alertContent) w := lapi.RecordResponse("GET", "/v1/alerts?test=test", alertContent, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String()) assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String())
//get without filters //get without filters
w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
//check alert and decision //check alert and decision
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
@ -191,149 +197,149 @@ func TestAlertListFilters(t *testing.T) {
//test decision_type filter (ok) //test decision_type filter (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ban", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ban", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test decision_type filter (bad value) //test decision_type filter (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ratata", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?decision_type=ratata", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String()) assert.Equal(t, "null", w.Body.String())
//test scope (ok) //test scope (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?scope=Ip", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?scope=Ip", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test scope (bad value) //test scope (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?scope=rarara", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?scope=rarara", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String()) assert.Equal(t, "null", w.Body.String())
//test scenario (ok) //test scenario (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/ssh-bf", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test scenario (bad value) //test scenario (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?scenario=crowdsecurity/nope", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String()) assert.Equal(t, "null", w.Body.String())
//test ip (ok) //test ip (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?ip=91.121.79.195", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?ip=91.121.79.195", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test ip (bad value) //test ip (bad value)
w = lapi.RecordResponse("GET", "/v1/alerts?ip=99.122.77.195", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?ip=99.122.77.195", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String()) assert.Equal(t, "null", w.Body.String())
//test ip (invalid value) //test ip (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?ip=gruueq", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?ip=gruueq", emptyBody, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String()) assert.Equal(t, `{"message":"unable to convert 'gruueq' to int: invalid address: invalid ip address / range"}`, w.Body.String())
//test range (ok) //test range (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?range=91.121.79.0/24&contains=false", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test range //test range
w = lapi.RecordResponse("GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?range=99.122.77.0/24&contains=false", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String()) assert.Equal(t, "null", w.Body.String())
//test range (invalid value) //test range (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?range=ratata", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?range=ratata", emptyBody, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String()) assert.Equal(t, `{"message":"unable to convert 'ratata' to int: invalid address: invalid ip address / range"}`, w.Body.String())
//test since (ok) //test since (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?since=1h", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?since=1h", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test since (ok but yields no results) //test since (ok but yields no results)
w = lapi.RecordResponse("GET", "/v1/alerts?since=1ns", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?since=1ns", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String()) assert.Equal(t, "null", w.Body.String())
//test since (invalid value) //test since (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?since=1zuzu", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?since=1zuzu", emptyBody, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`)
//test until (ok) //test until (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?until=1ns", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?until=1ns", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test until (ok but no return) //test until (ok but no return)
w = lapi.RecordResponse("GET", "/v1/alerts?until=1m", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?until=1m", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String()) assert.Equal(t, "null", w.Body.String())
//test until (invalid value) //test until (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?until=1zuzu", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?until=1zuzu", emptyBody, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`) assert.Contains(t, w.Body.String(), `{"message":"while parsing duration: time: unknown unit`)
//test simulated (ok) //test simulated (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?simulated=true", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test simulated (ok) //test simulated (ok)
w = lapi.RecordResponse("GET", "/v1/alerts?simulated=false", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?simulated=false", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test has active decision //test has active decision
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=true", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=true", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ") assert.Contains(t, w.Body.String(), "Ip 91.121.79.195 performed 'crowdsecurity/ssh-bf' (6 events over ")
assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`) assert.Contains(t, w.Body.String(), `scope":"Ip","simulated":false,"type":"ban","value":"91.121.79.195"`)
//test has active decision //test has active decision
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=false", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=false", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, "null", w.Body.String()) assert.Equal(t, "null", w.Body.String())
//test has active decision (invalid value) //test has active decision (invalid value)
w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts?has_active_decision=ratatqata", emptyBody, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String()) assert.Equal(t, `{"message":"'ratatqata' is not a boolean: strconv.ParseBool: parsing \"ratatqata\": invalid syntax: unable to parse type"}`, w.Body.String())
@ -345,7 +351,7 @@ func TestAlertBulkInsert(t *testing.T) {
lapi.InsertAlertFromFile("./tests/alert_bulk.json") lapi.InsertAlertFromFile("./tests/alert_bulk.json")
alertContent := GetAlertReaderFromFile("./tests/alert_bulk.json") alertContent := GetAlertReaderFromFile("./tests/alert_bulk.json")
w := lapi.RecordResponse("GET", "/v1/alerts", alertContent) w := lapi.RecordResponse("GET", "/v1/alerts", alertContent, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
} }
@ -354,13 +360,13 @@ func TestListAlert(t *testing.T) {
lapi.InsertAlertFromFile("./tests/alert_sample.json") lapi.InsertAlertFromFile("./tests/alert_sample.json")
// List Alert with invalid filter // List Alert with invalid filter
w := lapi.RecordResponse("GET", "/v1/alerts?test=test", emptyBody) w := lapi.RecordResponse("GET", "/v1/alerts?test=test", emptyBody, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String()) assert.Equal(t, "{\"message\":\"Filter parameter 'test' is unknown (=test): invalid filter\"}", w.Body.String())
// List Alert // List Alert
w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody) w = lapi.RecordResponse("GET", "/v1/alerts", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "crowdsecurity/test") assert.Contains(t, w.Body.String(), "crowdsecurity/test")
} }

View file

@ -2,8 +2,11 @@ package apiserver
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"os" "os"
@ -11,6 +14,7 @@ import (
"time" "time"
"github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers" "github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers"
v1 "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/csplugin"
"github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database"
@ -240,12 +244,56 @@ func (s *APIServer) Router() (*gin.Engine, error) {
return s.router, nil return s.router, nil
} }
func (s *APIServer) GetTLSConfig() (*tls.Config, error) {
var caCert []byte
var err error
var caCertPool *x509.CertPool
var clientAuthType tls.ClientAuthType
if s.TLS == nil {
return &tls.Config{}, nil
}
if s.TLS.ClientVerification == "" {
//sounds like a sane default : verify client cert if given, but don't make it mandatory
clientAuthType = tls.VerifyClientCertIfGiven
} else {
clientAuthType, err = getTLSAuthType(s.TLS.ClientVerification)
if err != nil {
return nil, err
}
}
if s.TLS.CACertPath != "" {
if clientAuthType > tls.RequestClientCert {
log.Infof("(tls) Client Auth Type set to %s", clientAuthType.String())
caCert, err = ioutil.ReadFile(s.TLS.CACertPath)
if err != nil {
return nil, errors.Wrap(err, "Error opening cert file")
}
caCertPool = x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
}
}
return &tls.Config{
ServerName: s.TLS.ServerName, //should it be removed ?
ClientAuth: clientAuthType,
ClientCAs: caCertPool,
MinVersion: tls.VersionTLS12, // TLS versions below 1.2 are considered insecure - see https://www.rfc-editor.org/rfc/rfc7525.txt for details
}, nil
}
func (s *APIServer) Run() error { func (s *APIServer) Run() error {
defer types.CatchPanic("lapi/runServer") defer types.CatchPanic("lapi/runServer")
tlsCfg, err := s.GetTLSConfig()
if err != nil {
return errors.Wrap(err, "while creating TLS config")
}
s.httpServer = &http.Server{ s.httpServer = &http.Server{
Addr: s.URL, Addr: s.URL,
Handler: s.router, Handler: s.router,
TLSConfig: tlsCfg,
} }
if s.apic != nil { if s.apic != nil {
@ -326,6 +374,36 @@ func (s *APIServer) AttachPluginBroker(broker *csplugin.PluginBroker) {
} }
func (s *APIServer) InitController() error { func (s *APIServer) InitController() error {
err := s.controller.Init() err := s.controller.Init()
if err != nil {
return errors.Wrap(err, "controller init")
}
if s.TLS != nil {
var cacheExpiration time.Duration
if s.TLS.CacheExpiration != nil {
cacheExpiration = *s.TLS.CacheExpiration
} else {
cacheExpiration = time.Hour
}
s.controller.HandlerV1.Middlewares.JWT.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedAgentsOU, s.TLS.CRLPath,
cacheExpiration,
log.WithFields(log.Fields{
"component": "tls-auth",
"type": "agent",
}))
if err != nil {
return errors.Wrap(err, "while creating TLS auth for agents")
}
s.controller.HandlerV1.Middlewares.APIKey.TlsAuth, err = v1.NewTLSAuth(s.TLS.AllowedBouncersOU, s.TLS.CRLPath,
cacheExpiration,
log.WithFields(log.Fields{
"component": "tls-auth",
"type": "bouncer",
}))
if err != nil {
return errors.Wrap(err, "while creating TLS auth for bouncers")
}
}
return err return err
} }

View file

@ -275,7 +275,7 @@ func CreateTestBouncer(config *csconfig.DatabaseCfg) (string, error) {
if err != nil { if err != nil {
return "", fmt.Errorf("unable to generate api key: %s", err) return "", fmt.Errorf("unable to generate api key: %s", err)
} }
err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey)) _, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to create blocker: %s", err) return "", fmt.Errorf("unable to create blocker: %s", err)
} }

View file

@ -25,6 +25,7 @@ type Controller struct {
Log *log.Logger Log *log.Logger
ConsoleConfig *csconfig.ConsoleConfig ConsoleConfig *csconfig.ConsoleConfig
TrustedIPs []net.IPNet TrustedIPs []net.IPNet
HandlerV1 *v1.Controller
} }
func (c *Controller) Init() error { func (c *Controller) Init() error {
@ -55,12 +56,22 @@ func serveHealth() http.HandlerFunc {
} }
func (c *Controller) NewV1() error { func (c *Controller) NewV1() error {
var err error
handlerV1, err := v1.New(c.DBClient, c.Ectx, c.Profiles, c.CAPIChan, c.PluginChannel, *c.ConsoleConfig, c.TrustedIPs) v1Config := v1.ControllerV1Config{
DbClient: c.DBClient,
Ctx: c.Ectx,
Profiles: c.Profiles,
CapiChan: c.CAPIChan,
PluginChannel: c.PluginChannel,
ConsoleConfig: *c.ConsoleConfig,
TrustedIPs: c.TrustedIPs,
}
c.HandlerV1, err = v1.New(&v1Config)
if err != nil { if err != nil {
return err return err
} }
c.Router.GET("/health", gin.WrapF(serveHealth())) c.Router.GET("/health", gin.WrapF(serveHealth()))
c.Router.Use(v1.PrometheusMiddleware()) c.Router.Use(v1.PrometheusMiddleware())
c.Router.HandleMethodNotAllowed = true c.Router.HandleMethodNotAllowed = true
@ -72,31 +83,32 @@ func (c *Controller) NewV1() error {
}) })
groupV1 := c.Router.Group("/v1") groupV1 := c.Router.Group("/v1")
groupV1.POST("/watchers", handlerV1.CreateMachine) groupV1.POST("/watchers", c.HandlerV1.CreateMachine)
groupV1.POST("/watchers/login", handlerV1.Middlewares.JWT.Middleware.LoginHandler) groupV1.POST("/watchers/login", c.HandlerV1.Middlewares.JWT.Middleware.LoginHandler)
jwtAuth := groupV1.Group("") jwtAuth := groupV1.Group("")
jwtAuth.GET("/refresh_token", handlerV1.Middlewares.JWT.Middleware.RefreshHandler) jwtAuth.GET("/refresh_token", c.HandlerV1.Middlewares.JWT.Middleware.RefreshHandler)
jwtAuth.Use(handlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware()) jwtAuth.Use(c.HandlerV1.Middlewares.JWT.Middleware.MiddlewareFunc(), v1.PrometheusMachinesMiddleware())
{ {
jwtAuth.POST("/alerts", handlerV1.CreateAlert) jwtAuth.POST("/alerts", c.HandlerV1.CreateAlert)
jwtAuth.GET("/alerts", handlerV1.FindAlerts) jwtAuth.GET("/alerts", c.HandlerV1.FindAlerts)
jwtAuth.HEAD("/alerts", handlerV1.FindAlerts) jwtAuth.HEAD("/alerts", c.HandlerV1.FindAlerts)
jwtAuth.GET("/alerts/:alert_id", handlerV1.FindAlertByID) jwtAuth.GET("/alerts/:alert_id", c.HandlerV1.FindAlertByID)
jwtAuth.HEAD("/alerts/:alert_id", handlerV1.FindAlertByID) jwtAuth.HEAD("/alerts/:alert_id", c.HandlerV1.FindAlertByID)
jwtAuth.DELETE("/alerts", handlerV1.DeleteAlerts) jwtAuth.DELETE("/alerts", c.HandlerV1.DeleteAlerts)
jwtAuth.DELETE("/decisions", handlerV1.DeleteDecisions) jwtAuth.DELETE("/decisions", c.HandlerV1.DeleteDecisions)
jwtAuth.DELETE("/decisions/:decision_id", handlerV1.DeleteDecisionById) jwtAuth.DELETE("/decisions/:decision_id", c.HandlerV1.DeleteDecisionById)
jwtAuth.GET("/heartbeat", handlerV1.HeartBeat) jwtAuth.GET("/heartbeat", c.HandlerV1.HeartBeat)
} }
apiKeyAuth := groupV1.Group("") apiKeyAuth := groupV1.Group("")
apiKeyAuth.Use(handlerV1.Middlewares.APIKey.MiddlewareFunc(), v1.PrometheusBouncersMiddleware()) apiKeyAuth.Use(c.HandlerV1.Middlewares.APIKey.MiddlewareFunc(), v1.PrometheusBouncersMiddleware())
{ {
apiKeyAuth.GET("/decisions", handlerV1.GetDecision) apiKeyAuth.GET("/decisions", c.HandlerV1.GetDecision)
apiKeyAuth.HEAD("/decisions", handlerV1.GetDecision) apiKeyAuth.HEAD("/decisions", c.HandlerV1.GetDecision)
apiKeyAuth.GET("/decisions/stream", handlerV1.StreamDecision) apiKeyAuth.GET("/decisions/stream", c.HandlerV1.StreamDecision)
apiKeyAuth.HEAD("/decisions/stream", handlerV1.StreamDecision) apiKeyAuth.HEAD("/decisions/stream", c.HandlerV1.StreamDecision)
} }
return nil return nil

View file

@ -4,6 +4,8 @@ import (
"context" "context"
"net" "net"
//"github.com/crowdsecurity/crowdsec/pkg/apiserver/controllers"
middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1" middlewares "github.com/crowdsecurity/crowdsec/pkg/apiserver/middlewares/v1"
"github.com/crowdsecurity/crowdsec/pkg/csconfig" "github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/csplugin" "github.com/crowdsecurity/crowdsec/pkg/csplugin"
@ -23,19 +25,29 @@ type Controller struct {
TrustedIPs []net.IPNet TrustedIPs []net.IPNet
} }
func New(dbClient *database.Client, ctx context.Context, profiles []*csconfig.ProfileCfg, capiChan chan []*models.Alert, pluginChannel chan csplugin.ProfileAlert, consoleConfig csconfig.ConsoleConfig, trustedIPs []net.IPNet) (*Controller, error) { type ControllerV1Config struct {
DbClient *database.Client
Ctx context.Context
Profiles []*csconfig.ProfileCfg
CapiChan chan []*models.Alert
PluginChannel chan csplugin.ProfileAlert
ConsoleConfig csconfig.ConsoleConfig
TrustedIPs []net.IPNet
}
func New(cfg *ControllerV1Config) (*Controller, error) {
var err error var err error
v1 := &Controller{ v1 := &Controller{
Ectx: ctx, Ectx: cfg.Ctx,
DBClient: dbClient, DBClient: cfg.DbClient,
APIKeyHeader: middlewares.APIKeyHeader, APIKeyHeader: middlewares.APIKeyHeader,
Profiles: profiles, Profiles: cfg.Profiles,
CAPIChan: capiChan, CAPIChan: cfg.CapiChan,
PluginChannel: pluginChannel, PluginChannel: cfg.PluginChannel,
ConsoleConfig: consoleConfig, ConsoleConfig: cfg.ConsoleConfig,
TrustedIPs: trustedIPs, TrustedIPs: cfg.TrustedIPs,
} }
v1.Middlewares, err = middlewares.NewMiddlewares(dbClient) v1.Middlewares, err = middlewares.NewMiddlewares(cfg.DbClient)
if err != nil { if err != nil {
return v1, err return v1, err
} }

View file

@ -4,6 +4,7 @@ import (
"net/http" "net/http"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt" "github.com/go-openapi/strfmt"
) )
@ -20,7 +21,7 @@ func (c *Controller) CreateMachine(gctx *gin.Context) {
return return
} }
_, err = c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false) _, err = c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType)
if err != nil { if err != nil {
c.HandleDBErrors(gctx, err) c.HandleDBErrors(gctx, err)
return return

View file

@ -14,20 +14,20 @@ func TestDeleteDecisionRange(t *testing.T) {
lapi.InsertAlertFromFile("./tests/alert_minibulk.json") lapi.InsertAlertFromFile("./tests/alert_minibulk.json")
// delete by ip wrong // delete by ip wrong
w := lapi.RecordResponse("DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody) w := lapi.RecordResponse("DELETE", "/v1/decisions?range=1.2.3.0/24", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by range // delete by range
w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String()) assert.Equal(t, `{"nbDeleted":"2"}`, w.Body.String())
// delete by range : ensure it was already deleted // delete by range : ensure it was already deleted
w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions?range=91.121.79.0/24", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
} }
@ -40,19 +40,19 @@ func TestDeleteDecisionFilter(t *testing.T) {
// delete by ip wrong // delete by ip wrong
w := lapi.RecordResponse("DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody) w := lapi.RecordResponse("DELETE", "/v1/decisions?ip=1.2.3.4", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String()) assert.Equal(t, `{"nbDeleted":"0"}`, w.Body.String())
// delete by ip good // delete by ip good
w = lapi.RecordResponse("DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions?ip=91.121.79.179", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
// delete by scope/value // delete by scope/value
w = lapi.RecordResponse("DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions?scopes=Ip&value=91.121.79.178", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String()) assert.Equal(t, `{"nbDeleted":"1"}`, w.Body.String())
} }
@ -65,7 +65,7 @@ func TestGetDecisionFilters(t *testing.T) {
// Get Decision // Get Decision
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody) w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, "apikey")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err := readDecisionsGetResp(w) decisions, code, err := readDecisionsGetResp(w)
assert.Nil(t, err) assert.Nil(t, err)
@ -80,7 +80,7 @@ func TestGetDecisionFilters(t *testing.T) {
// Get Decision : type filter // Get Decision : type filter
w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions?type=ban", emptyBody, "apikey")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w) decisions, code, err = readDecisionsGetResp(w)
assert.Nil(t, err) assert.Nil(t, err)
@ -98,7 +98,7 @@ func TestGetDecisionFilters(t *testing.T) {
// Get Decision : scope/value // Get Decision : scope/value
w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions?scopes=Ip&value=91.121.79.179", emptyBody, "apikey")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w) decisions, code, err = readDecisionsGetResp(w)
assert.Nil(t, err) assert.Nil(t, err)
@ -113,7 +113,7 @@ func TestGetDecisionFilters(t *testing.T) {
// Get Decision : ip filter // Get Decision : ip filter
w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions?ip=91.121.79.179", emptyBody, "apikey")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w) decisions, code, err = readDecisionsGetResp(w)
assert.Nil(t, err) assert.Nil(t, err)
@ -127,7 +127,7 @@ func TestGetDecisionFilters(t *testing.T) {
// assert.NotContains(t, w.Body.String(), `"id":2,"origin":"crowdsec","scenario":"crowdsecurity/ssh-bf","scope":"Ip","type":"ban","value":"91.121.79.178"`) // assert.NotContains(t, w.Body.String(), `"id":2,"origin":"crowdsec","scenario":"crowdsecurity/ssh-bf","scope":"Ip","type":"ban","value":"91.121.79.178"`)
// Get decision : by range // Get decision : by range
w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions?range=91.121.79.0/24&contains=false", emptyBody, "apikey")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err = readDecisionsGetResp(w) decisions, code, err = readDecisionsGetResp(w)
assert.Nil(t, err) assert.Nil(t, err)
@ -145,7 +145,7 @@ func TestGetDecision(t *testing.T) {
lapi.InsertAlertFromFile("./tests/alert_sample.json") lapi.InsertAlertFromFile("./tests/alert_sample.json")
// Get Decision // Get Decision
w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody) w := lapi.RecordResponse("GET", "/v1/decisions", emptyBody, "apikey")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
decisions, code, err := readDecisionsGetResp(w) decisions, code, err := readDecisionsGetResp(w)
assert.Nil(t, err) assert.Nil(t, err)
@ -165,7 +165,7 @@ func TestGetDecision(t *testing.T) {
assert.Equal(t, int64(3), decisions[2].ID) assert.Equal(t, int64(3), decisions[2].ID)
// Get Decision with invalid filter. It should ignore this filter // Get Decision with invalid filter. It should ignore this filter
w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions?test=test", emptyBody, "apikey")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
assert.Equal(t, 3, len(decisions)) assert.Equal(t, 3, len(decisions))
} }
@ -177,7 +177,7 @@ func TestDeleteDecisionByID(t *testing.T) {
lapi.InsertAlertFromFile("./tests/alert_sample.json") lapi.InsertAlertFromFile("./tests/alert_sample.json")
//Have one alerts //Have one alerts
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey")
decisions, code, err := readDecisionsStreamResp(w) decisions, code, err := readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, code, 200) assert.Equal(t, code, 200)
@ -185,21 +185,21 @@ func TestDeleteDecisionByID(t *testing.T) {
assert.Equal(t, len(decisions["new"]), 1) assert.Equal(t, len(decisions["new"]), 1)
// Delete alert with Invalid ID // Delete alert with Invalid ID
w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions/test", emptyBody, "password")
assert.Equal(t, 400, w.Code) assert.Equal(t, 400, w.Code)
err_resp, _, err := readDecisionsErrorResp(w) err_resp, _, err := readDecisionsErrorResp(w)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, err_resp["message"], "decision_id must be valid integer") assert.Equal(t, err_resp["message"], "decision_id must be valid integer")
// Delete alert with ID that not exist // Delete alert with ID that not exist
w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions/100", emptyBody, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
err_resp, _, err = readDecisionsErrorResp(w) err_resp, _, err = readDecisionsErrorResp(w)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, err_resp["message"], "decision with id '100' doesn't exist: unable to delete") assert.Equal(t, err_resp["message"], "decision with id '100' doesn't exist: unable to delete")
//Have one alerts //Have one alerts
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, code, 200) assert.Equal(t, code, 200)
@ -207,14 +207,14 @@ func TestDeleteDecisionByID(t *testing.T) {
assert.Equal(t, len(decisions["new"]), 1) assert.Equal(t, len(decisions["new"]), 1)
// Delete alert with valid ID // Delete alert with valid ID
w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
resp, _, err := readDecisionsDeleteResp(w) resp, _, err := readDecisionsDeleteResp(w)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, resp.NbDeleted, "1") assert.Equal(t, resp.NbDeleted, "1")
//Have one alert (because we delete an alert that has dup targets) //Have one alert (because we delete an alert that has dup targets)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, code, 200) assert.Equal(t, code, 200)
@ -229,14 +229,14 @@ func TestDeleteDecision(t *testing.T) {
lapi.InsertAlertFromFile("./tests/alert_sample.json") lapi.InsertAlertFromFile("./tests/alert_sample.json")
// Delete alert with Invalid filter // Delete alert with Invalid filter
w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody) w := lapi.RecordResponse("DELETE", "/v1/decisions?test=test", emptyBody, "password")
assert.Equal(t, 500, w.Code) assert.Equal(t, 500, w.Code)
err_resp, _, err := readDecisionsErrorResp(w) err_resp, _, err := readDecisionsErrorResp(w)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, err_resp["message"], "'test' doesn't exist: invalid filter") assert.Equal(t, err_resp["message"], "'test' doesn't exist: invalid filter")
// Delete all alert // Delete all alert
w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
resp, _, err := readDecisionsDeleteResp(w) resp, _, err := readDecisionsDeleteResp(w)
assert.NoError(t, err) assert.NoError(t, err)
@ -251,7 +251,7 @@ func TestStreamStartDecisionDedup(t *testing.T) {
lapi.InsertAlertFromFile("./tests/alert_sample.json") lapi.InsertAlertFromFile("./tests/alert_sample.json")
// Get Stream, we only get one decision (the longest one) // Get Stream, we only get one decision (the longest one)
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey")
decisions, code, err := readDecisionsStreamResp(w) decisions, code, err := readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, code, 200) assert.Equal(t, code, 200)
@ -262,11 +262,11 @@ func TestStreamStartDecisionDedup(t *testing.T) {
assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1") assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1")
// id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip // id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip
w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
// Get Stream, we only get one decision (the longest one, id=2) // Get Stream, we only get one decision (the longest one, id=2)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, code, 200) assert.Equal(t, code, 200)
@ -277,11 +277,11 @@ func TestStreamStartDecisionDedup(t *testing.T) {
assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1") assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1")
// We delete another decision, yet don't receive it in stream, since there's another decision on same IP // We delete another decision, yet don't receive it in stream, since there's another decision on same IP
w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
// And get the remaining decision (1) // And get the remaining decision (1)
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, code, 200) assert.Equal(t, code, 200)
@ -292,11 +292,11 @@ func TestStreamStartDecisionDedup(t *testing.T) {
assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1") assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1")
// We delete the last decision, we receive the delete order // We delete the last decision, we receive the delete order
w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
//and now we only get a deleted decision //and now we only get a deleted decision
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, code, 200) assert.Equal(t, code, 200)
@ -317,7 +317,7 @@ func TestStreamDecisionDedup(t *testing.T) {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
// Get Stream, we only get one decision (the longest one) // Get Stream, we only get one decision (the longest one)
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey")
decisions, code, err := readDecisionsStreamResp(w) decisions, code, err := readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, code, 200) assert.Equal(t, code, 200)
@ -328,10 +328,10 @@ func TestStreamDecisionDedup(t *testing.T) {
assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1") assert.Equal(t, *decisions["new"][0].Value, "127.0.0.1")
// id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip // id=3 decision is deleted, this won't affect `deleted`, because there are decisions on the same ip
w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions/3", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody, "apikey")
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
@ -339,10 +339,10 @@ func TestStreamDecisionDedup(t *testing.T) {
assert.Equal(t, len(decisions["deleted"]), 0) assert.Equal(t, len(decisions["deleted"]), 0)
assert.Equal(t, len(decisions["new"]), 0) assert.Equal(t, len(decisions["new"]), 0)
// We delete another decision, yet don't receive it in stream, since there's another decision on same IP // We delete another decision, yet don't receive it in stream, since there's another decision on same IP
w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions/2", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, code, 200) assert.Equal(t, code, 200)
@ -350,10 +350,10 @@ func TestStreamDecisionDedup(t *testing.T) {
assert.Equal(t, len(decisions["new"]), 0) assert.Equal(t, len(decisions["new"]), 0)
// We delete the last decision, we receive the delete order // We delete the last decision, we receive the delete order
w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody) w = lapi.RecordResponse("DELETE", "/v1/decisions/1", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, code, 200) assert.Equal(t, code, 200)
@ -371,7 +371,7 @@ func TestStreamDecisionFilters(t *testing.T) {
// Create Valid Alert // Create Valid Alert
lapi.InsertAlertFromFile("./tests/alert_stream_fixture.json") lapi.InsertAlertFromFile("./tests/alert_stream_fixture.json")
w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody) w := lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true", emptyBody, "apikey")
decisions, code, err := readDecisionsStreamResp(w) decisions, code, err := readDecisionsStreamResp(w)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
@ -392,7 +392,7 @@ func TestStreamDecisionFilters(t *testing.T) {
assert.Equal(t, *decisions["new"][2].Scenario, "crowdsecurity/ddos") assert.Equal(t, *decisions["new"][2].Scenario, "crowdsecurity/ddos")
// test filter scenarios_not_containing // test filter scenarios_not_containing
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_not_containing=http", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_not_containing=http", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
@ -402,7 +402,7 @@ func TestStreamDecisionFilters(t *testing.T) {
assert.Equal(t, decisions["new"][1].ID, int64(3)) assert.Equal(t, decisions["new"][1].ID, int64(3))
// test filter scenarios_containing // test filter scenarios_containing
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_containing=http", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_containing=http", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
@ -411,7 +411,7 @@ func TestStreamDecisionFilters(t *testing.T) {
assert.Equal(t, decisions["new"][0].ID, int64(1)) assert.Equal(t, decisions["new"][0].ID, int64(1))
// test filters both by scenarios_not_containing and scenarios_containing // test filters both by scenarios_not_containing and scenarios_containing
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh&scenarios_containing=ddos", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&scenarios_not_containing=ssh&scenarios_containing=ddos", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
@ -420,7 +420,7 @@ func TestStreamDecisionFilters(t *testing.T) {
assert.Equal(t, decisions["new"][0].ID, int64(3)) assert.Equal(t, decisions["new"][0].ID, int64(3))
// test filter by origin // test filter by origin
w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&origins=test1,test2", emptyBody) w = lapi.RecordResponse("GET", "/v1/decisions/stream?startup=true&origins=test1,test2", emptyBody, "apikey")
decisions, code, err = readDecisionsStreamResp(w) decisions, code, err = readDecisionsStreamResp(w)
assert.Equal(t, err, nil) assert.Equal(t, err, nil)
assert.Equal(t, 200, code) assert.Equal(t, 200, code)

View file

@ -9,9 +9,9 @@ import (
func TestHeartBeat(t *testing.T) { func TestHeartBeat(t *testing.T) {
lapi := SetupLAPITest(t) lapi := SetupLAPITest(t)
w := lapi.RecordResponse("GET", "/v1/heartbeat", emptyBody) w := lapi.RecordResponse("GET", "/v1/heartbeat", emptyBody, "password")
assert.Equal(t, 200, w.Code) assert.Equal(t, 200, w.Code)
w = lapi.RecordResponse("POST", "/v1/heartbeat", emptyBody) w = lapi.RecordResponse("POST", "/v1/heartbeat", emptyBody, "password")
assert.Equal(t, 405, w.Code) assert.Equal(t, 405, w.Code)
} }

View file

@ -9,6 +9,8 @@ import (
"strings" "strings"
"github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -21,6 +23,7 @@ var (
type APIKey struct { type APIKey struct {
HeaderName string HeaderName string
DbClient *database.Client DbClient *database.Client
TlsAuth *TLSAuth
} }
func GenerateAPIKey(n int) (string, error) { func GenerateAPIKey(n int) (string, error) {
@ -35,6 +38,7 @@ func NewAPIKey(dbClient *database.Client) *APIKey {
return &APIKey{ return &APIKey{
HeaderName: APIKeyHeader, HeaderName: APIKeyHeader,
DbClient: dbClient, DbClient: dbClient,
TlsAuth: &TLSAuth{},
} }
} }
@ -49,34 +53,132 @@ func HashSHA512(str string) string {
func (a *APIKey) MiddlewareFunc() gin.HandlerFunc { func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
val, ok := c.Request.Header[APIKeyHeader] var bouncer *ent.Bouncer
if !ok { var err error
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
hashStr := HashSHA512(val[0]) if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
bouncer, err := a.DbClient.SelectBouncer(hashStr) if a.TlsAuth == nil {
if err != nil { log.WithField("ip", c.ClientIP()).Error("TLS Auth is not configured but client presented a certificate")
log.Errorf("auth api key error: %s", err) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.Abort()
c.Abort() return
return }
validCert, extractedCN, err := a.TlsAuth.ValidateCert(c)
if !validCert {
log.WithField("ip", c.ClientIP()).Errorf("invalid client certificate: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
if err != nil {
log.WithField("ip", c.ClientIP()).Error(err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
bouncer, err = a.DbClient.SelectBouncerByName(bouncerName)
//This is likely not the proper way, but isNotFound does not seem to work
if err != nil && strings.Contains(err.Error(), "bouncer not found") {
//Because we have a valid cert, automatically create the bouncer in the database if it does not exist
//Set a random API key, but it will never be used
apiKey, err := GenerateAPIKey(64)
if err != nil {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Errorf("error generating mock api key: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Infof("Creating bouncer %s", bouncerName)
bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType)
if err != nil {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Errorf("creating bouncer db entry : %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
} else if err != nil {
//error while selecting bouncer
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Errorf("while selecting bouncers: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
} else {
//bouncer was found in DB
if bouncer.AuthType != types.TlsAuthType {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Errorf("bouncer isn't allowed to auth by TLS")
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
}
} else {
//API Key Authentication
val, ok := c.Request.Header[APIKeyHeader]
if !ok {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
}).Errorf("API key not found")
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
hashStr := HashSHA512(val[0])
bouncer, err = a.DbClient.SelectBouncer(hashStr)
if err != nil {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
}).Errorf("while fetching bouncer info: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
if bouncer.AuthType != types.ApiKeyAuthType {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
}).Errorf("bouncer %s attempted to login using an API key but it is configured to auth with %s", bouncer.Name, bouncer.AuthType)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
} }
if bouncer == nil { if bouncer == nil {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
}).Errorf("bouncer not found")
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
return return
} }
//maybe we want to store the whole bouncer object in the context instead, this would avoid another db query
//in StreamDecision
c.Set("BOUNCER_NAME", bouncer.Name) c.Set("BOUNCER_NAME", bouncer.Name)
c.Set("BOUNCER_HASHED_KEY", bouncer.APIKey)
if bouncer.IPAddress == "" { if bouncer.IPAddress == "" {
err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID) err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID)
if err != nil { if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) log.WithFields(log.Fields{
"ip": c.ClientIP(),
"name": bouncer.Name,
}).Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
return return
@ -87,7 +189,10 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress) log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress)
err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID) err = a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID)
if err != nil { if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err) log.WithFields(log.Fields{
"ip": c.ClientIP(),
"name": bouncer.Name,
}).Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"}) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort() c.Abort()
return return
@ -97,13 +202,19 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
useragent := strings.Split(c.Request.UserAgent(), "/") useragent := strings.Split(c.Request.UserAgent(), "/")
if len(useragent) != 2 { if len(useragent) != 2 {
log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), c.ClientIP()) log.WithFields(log.Fields{
"ip": c.ClientIP(),
"name": bouncer.Name,
}).Warningf("bad user agent '%s'", c.Request.UserAgent())
useragent = []string{c.Request.UserAgent(), "N/A"} useragent = []string{c.Request.UserAgent(), "N/A"}
} }
if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] { if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] {
if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil { if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil {
log.Errorf("failed to update bouncer version and type from '%s' (%s): %s", c.Request.UserAgent(), c.ClientIP(), err) log.WithFields(log.Fields{
"ip": c.ClientIP(),
"name": bouncer.Name,
}).Errorf("failed to update bouncer version and type: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"}) c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"})
c.Abort() c.Abort()
return return

View file

@ -3,14 +3,17 @@ package v1
import ( import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"net/http"
"os" "os"
"strings" "strings"
"time" "time"
jwt "github.com/appleboy/gin-jwt/v2" jwt "github.com/appleboy/gin-jwt/v2"
"github.com/crowdsecurity/crowdsec/pkg/database" "github.com/crowdsecurity/crowdsec/pkg/database"
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine" "github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-openapi/strfmt" "github.com/go-openapi/strfmt"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -23,6 +26,7 @@ var identityKey = "id"
type JWT struct { type JWT struct {
Middleware *jwt.GinJWTMiddleware Middleware *jwt.GinJWTMiddleware
DbClient *database.Client DbClient *database.Client
TlsAuth *TLSAuth
} }
func PayloadFunc(data interface{}) jwt.MapClaims { func PayloadFunc(data interface{}) jwt.MapClaims {
@ -46,35 +50,109 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
var loginInput models.WatcherAuthRequest var loginInput models.WatcherAuthRequest
var scenarios string var scenarios string
var err error var err error
if err := c.ShouldBindJSON(&loginInput); err != nil { var scenariosInput []string
return "", errors.Wrap(err, "missing") var clientMachine *ent.Machine
} var machineID string
if err := loginInput.Validate(strfmt.Default); err != nil { var password strfmt.Password
return "", errors.New("input format error")
}
machineID := *loginInput.MachineID
password := *loginInput.Password
scenariosInput := loginInput.Scenarios
machine, err := j.DbClient.Ent.Machine.Query(). if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
Where(machine.MachineId(machineID)). if j.TlsAuth == nil {
First(j.DbClient.CTX) c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
if err != nil { c.Abort()
log.Printf("Error machine login for %s : %+v ", machineID, err) return nil, errors.New("TLS auth is not configured")
return nil, err }
} validCert, extractedCN, err := j.TlsAuth.ValidateCert(c)
if err != nil {
log.Error(err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return nil, errors.Wrap(err, "while trying to validate client cert")
}
if !validCert {
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return nil, fmt.Errorf("failed cert authentication")
}
if machine == nil { machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
log.Errorf("Nothing for '%s'", machineID) clientMachine, err = j.DbClient.Ent.Machine.Query().
return nil, jwt.ErrFailedAuthentication Where(machine.MachineId(machineID)).
} First(j.DbClient.CTX)
if ent.IsNotFound(err) {
//Machine was not found, let's create it
log.Printf("machine %s not found, create it", machineID)
//let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli)
pwd, err := GenerateAPIKey(64)
if err != nil {
log.WithFields(log.Fields{
"ip": c.ClientIP(),
"cn": extractedCN,
}).Errorf("error generating password: %s", err)
return nil, fmt.Errorf("error generating password")
}
password := strfmt.Password(pwd)
clientMachine, err = j.DbClient.CreateMachine(&machineID, &password, "", true, true, types.TlsAuthType)
if err != nil {
return "", errors.Wrapf(err, "while creating machine entry for %s", machineID)
}
} else if err != nil {
return "", errors.Wrapf(err, "while selecting machine entry for %s", machineID)
} else {
if clientMachine.AuthType != types.TlsAuthType {
return "", errors.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", machineID, clientMachine.AuthType)
}
machineID = clientMachine.MachineId
loginInput := struct {
Scenarios []string `json:"scenarios"`
}{
Scenarios: []string{},
}
err := c.ShouldBindJSON(&loginInput)
if err != nil {
return "", errors.Wrap(err, "missing scenarios list in login request for TLS auth")
}
scenariosInput = loginInput.Scenarios
}
if !machine.IsValidated { } else {
return nil, fmt.Errorf("machine %s not validated", machineID) //normal auth
}
if err = bcrypt.CompareHashAndPassword([]byte(machine.Password), []byte(password)); err != nil { if err := c.ShouldBindJSON(&loginInput); err != nil {
return nil, jwt.ErrFailedAuthentication return "", errors.Wrap(err, "missing")
}
if err := loginInput.Validate(strfmt.Default); err != nil {
return "", errors.New("input format error")
}
machineID = *loginInput.MachineID
password = *loginInput.Password
scenariosInput = loginInput.Scenarios
clientMachine, err = j.DbClient.Ent.Machine.Query().
Where(machine.MachineId(machineID)).
First(j.DbClient.CTX)
if err != nil {
log.Printf("Error machine login for %s : %+v ", machineID, err)
return nil, err
}
if clientMachine == nil {
log.Errorf("Nothing for '%s'", machineID)
return nil, jwt.ErrFailedAuthentication
}
if clientMachine.AuthType != types.PasswordAuthType {
return nil, errors.Errorf("machine %s attempted to auth with password but it is configured to use %s", machineID, clientMachine.AuthType)
}
if !clientMachine.IsValidated {
return nil, fmt.Errorf("machine %s not validated", machineID)
}
if err = bcrypt.CompareHashAndPassword([]byte(clientMachine.Password), []byte(password)); err != nil {
return nil, jwt.ErrFailedAuthentication
}
//end of normal auth
} }
if len(scenariosInput) > 0 { if len(scenariosInput) > 0 {
@ -85,26 +163,26 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
scenarios += "," + scenario scenarios += "," + scenario
} }
} }
err = j.DbClient.UpdateMachineScenarios(scenarios, machine.ID) err = j.DbClient.UpdateMachineScenarios(scenarios, clientMachine.ID)
if err != nil { if err != nil {
log.Errorf("Failed to update scenarios list for '%s': %s\n", machineID, err) log.Errorf("Failed to update scenarios list for '%s': %s\n", machineID, err)
return nil, jwt.ErrFailedAuthentication return nil, jwt.ErrFailedAuthentication
} }
} }
if machine.IpAddress == "" { if clientMachine.IpAddress == "" {
err = j.DbClient.UpdateMachineIP(c.ClientIP(), machine.ID) err = j.DbClient.UpdateMachineIP(c.ClientIP(), clientMachine.ID)
if err != nil { if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", machineID, err) log.Errorf("Failed to update ip address for '%s': %s\n", machineID, err)
return nil, jwt.ErrFailedAuthentication return nil, jwt.ErrFailedAuthentication
} }
} }
if machine.IpAddress != c.ClientIP() && machine.IpAddress != "" { if clientMachine.IpAddress != c.ClientIP() && clientMachine.IpAddress != "" {
log.Warningf("new IP address detected for machine '%s': %s (old: %s)", machine.MachineId, c.ClientIP(), machine.IpAddress) log.Warningf("new IP address detected for machine '%s': %s (old: %s)", clientMachine.MachineId, c.ClientIP(), clientMachine.IpAddress)
err = j.DbClient.UpdateMachineIP(c.ClientIP(), machine.ID) err = j.DbClient.UpdateMachineIP(c.ClientIP(), clientMachine.ID)
if err != nil { if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", machine.MachineId, err) log.Errorf("Failed to update ip address for '%s': %s\n", clientMachine.MachineId, err)
return nil, jwt.ErrFailedAuthentication return nil, jwt.ErrFailedAuthentication
} }
} }
@ -115,12 +193,11 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
return nil, jwt.ErrFailedAuthentication return nil, jwt.ErrFailedAuthentication
} }
if err := j.DbClient.UpdateMachineVersion(useragent[1], machine.ID); err != nil { if err := j.DbClient.UpdateMachineVersion(useragent[1], clientMachine.ID); err != nil {
log.Errorf("unable to update machine '%s' version '%s': %s", machine.MachineId, useragent[1], err) log.Errorf("unable to update machine '%s' version '%s': %s", clientMachine.MachineId, useragent[1], err)
log.Errorf("bad user agent from : %s", c.ClientIP()) log.Errorf("bad user agent from : %s", c.ClientIP())
return nil, jwt.ErrFailedAuthentication return nil, jwt.ErrFailedAuthentication
} }
return &models.WatcherAuthRequest{ return &models.WatcherAuthRequest{
MachineID: &machineID, MachineID: &machineID,
}, nil }, nil
@ -178,6 +255,7 @@ func NewJWT(dbClient *database.Client) (*JWT, error) {
jwtMiddleware := &JWT{ jwtMiddleware := &JWT{
DbClient: dbClient, DbClient: dbClient,
TlsAuth: &TLSAuth{},
} }
ret, err := jwt.New(&jwt.GinJWTMiddleware{ ret, err := jwt.New(&jwt.GinJWTMiddleware{
@ -195,15 +273,15 @@ func NewJWT(dbClient *database.Client) (*JWT, error) {
TokenHeadName: "Bearer", TokenHeadName: "Bearer",
TimeFunc: time.Now, TimeFunc: time.Now,
}) })
if err != nil {
return &JWT{}, err
}
errInit := ret.MiddlewareInit() errInit := ret.MiddlewareInit()
if errInit != nil { if errInit != nil {
return &JWT{}, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error()) return &JWT{}, fmt.Errorf("authMiddleware.MiddlewareInit() Error:" + errInit.Error())
} }
jwtMiddleware.Middleware = ret
if err != nil { return jwtMiddleware, nil
return &JWT{}, err
}
return &JWT{Middleware: ret}, nil
} }

View file

@ -18,6 +18,5 @@ func NewMiddlewares(dbClient *database.Client) (*Middlewares, error) {
} }
ret.APIKey = NewAPIKey(dbClient) ret.APIKey = NewAPIKey(dbClient)
return ret, nil return ret, nil
} }

View file

@ -0,0 +1,256 @@
package v1
import (
"bytes"
"crypto"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ocsp"
)
type TLSAuth struct {
AllowedOUs []string
CrlPath string
revokationCache map[string]cacheEntry
cacheExpiration time.Duration
logger *log.Entry
}
type cacheEntry struct {
revoked bool
err error
timestamp time.Time
}
func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509.Certificate) (*ocsp.Response, error) {
req, err := ocsp.CreateRequest(cert, issuer, &ocsp.RequestOptions{Hash: crypto.SHA256})
if err != nil {
ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err)
return nil, err
}
httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req))
if err != nil {
ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP")
return nil, err
}
ocspURL, err := url.Parse(server)
if err != nil {
ta.logger.Error("TLSAuth: cannot parse OCSP URL")
return nil, err
}
httpRequest.Header.Add("Content-Type", "application/ocsp-request")
httpRequest.Header.Add("Accept", "application/ocsp-response")
httpRequest.Header.Add("host", ocspURL.Host)
httpClient := &http.Client{}
httpResponse, err := httpClient.Do(httpRequest)
if err != nil {
ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP")
return nil, err
}
defer httpResponse.Body.Close()
output, err := ioutil.ReadAll(httpResponse.Body)
if err != nil {
ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP")
return nil, err
}
ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer)
return ocspResponse, err
}
func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool {
now := time.Now().UTC()
if cert.NotAfter.UTC().Before(now) {
ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC())
return true
}
if cert.NotBefore.UTC().After(now) {
ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC())
return true
}
return false
}
func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
if cert.OCSPServer == nil || (cert.OCSPServer != nil && len(cert.OCSPServer) == 0) {
ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification")
return false, nil
}
for _, server := range cert.OCSPServer {
ocspResponse, err := ta.ocspQuery(server, cert, issuer)
if err != nil {
ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err)
continue
}
switch ocspResponse.Status {
case ocsp.Good:
return false, nil
case ocsp.Revoked:
return true, fmt.Errorf("client certificate is revoked by server %s", server)
case ocsp.Unknown:
log.Debugf("unknow OCSP status for server %s", server)
continue
}
}
log.Infof("Could not get any valid OCSP response, assuming the cert is revoked")
return true, nil
}
func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) {
if ta.CrlPath == "" {
ta.logger.Warn("no crl_path, skipping CRL check")
return false, nil
}
crlContent, err := ioutil.ReadFile(ta.CrlPath)
if err != nil {
ta.logger.Warnf("could not read CRL file, skipping check: %s", err)
return false, nil
}
crl, err := x509.ParseCRL(crlContent)
if err != nil {
ta.logger.Warnf("could not parse CRL file, skipping check: %s", err)
return false, nil
}
if crl.HasExpired(time.Now().UTC()) {
ta.logger.Warn("CRL has expired, will still validate the cert against it.")
}
for _, revoked := range crl.TBSCertList.RevokedCertificates {
if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 {
return true, fmt.Errorf("client certificate is revoked by CRL")
}
}
return false, nil
}
func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
sn := cert.SerialNumber.String()
if cacheValue, ok := ta.revokationCache[sn]; ok {
if time.Now().UTC().Sub(cacheValue.timestamp) < ta.cacheExpiration {
ta.logger.Debugf("TLSAuth: using cached value for cert %s: %t | %s", sn, cacheValue.revoked, cacheValue.err)
return cacheValue.revoked, cacheValue.err
} else {
ta.logger.Debugf("TLSAuth: cached value expired, removing from cache")
delete(ta.revokationCache, sn)
}
} else {
ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn)
}
revoked, err := ta.isOCSPRevoked(cert, issuer)
if err != nil {
ta.revokationCache[sn] = cacheEntry{
revoked: revoked,
err: err,
timestamp: time.Now().UTC(),
}
return true, err
}
if revoked {
ta.revokationCache[sn] = cacheEntry{
revoked: revoked,
err: err,
timestamp: time.Now().UTC(),
}
return true, nil
}
revoked, err = ta.isCRLRevoked(cert)
ta.revokationCache[sn] = cacheEntry{
revoked: revoked,
err: err,
timestamp: time.Now().UTC(),
}
return revoked, err
}
func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (bool, error) {
if ta.isExpired(cert) {
return true, nil
}
revoked, err := ta.isRevoked(cert, issuer)
if err != nil {
//Fail securely, if we can't check the revokation status, let's consider the cert invalid
//We may change this in the future based on users feedback, but this seems the most sensible thing to do
return true, errors.Wrap(err, "could not check for client certification revokation status")
}
return revoked, nil
}
func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error {
for _, ou := range allowedOus {
//disallow empty ou
if ou == "" {
return fmt.Errorf("empty ou isn't allowed")
}
//drop & warn on duplicate ou
ok := true
for _, validOu := range ta.AllowedOUs {
if validOu == ou {
ta.logger.Warningf("dropping duplicate ou %s", ou)
ok = false
}
}
if ok {
ta.AllowedOUs = append(ta.AllowedOUs, ou)
}
}
return nil
}
func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
//Checks cert validity, Returns true + CN if client cert matches requested OU
var clientCert *x509.Certificate
if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 {
//do not error if it's not TLS or there are no peer certs
return false, "", nil
}
if len(c.Request.TLS.VerifiedChains) > 0 {
validOU := false
clientCert = c.Request.TLS.VerifiedChains[0][0]
for _, ou := range clientCert.Subject.OrganizationalUnit {
for _, allowedOu := range ta.AllowedOUs {
if allowedOu == ou {
validOU = true
break
}
}
}
if !validOU {
return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)",
clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
}
revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1])
if err != nil {
ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err)
return false, "", errors.Wrap(err, "could not check for client certification revokation status")
}
if revoked {
return false, "", fmt.Errorf("client certificate is revoked")
}
ta.logger.Infof("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
return true, clientCert.Subject.CommonName, nil
}
return false, "", fmt.Errorf("no verified cert in request")
}
func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Duration, logger *log.Entry) (*TLSAuth, error) {
ta := &TLSAuth{
revokationCache: map[string]cacheEntry{},
cacheExpiration: cacheExpiration,
CrlPath: crlPath,
logger: logger,
}
err := ta.SetAllowedOu(allowedOus)
if err != nil {
return nil, err
}
return ta, nil
}

27
pkg/apiserver/utils.go Normal file
View file

@ -0,0 +1,27 @@
package apiserver
import (
"crypto/tls"
"fmt"
log "github.com/sirupsen/logrus"
)
func getTLSAuthType(authType string) (tls.ClientAuthType, error) {
switch authType {
case "NoClientCert":
return tls.NoClientCert, nil
case "RequestClientCert":
log.Warn("RequestClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead")
return tls.RequestClientCert, nil
case "RequireAnyClientCert":
log.Warn("RequireAnyClientCert is insecure, please use VerifyClientCertIfGiven or RequireAndVerifyClientCert instead")
return tls.RequireAnyClientCert, nil
case "VerifyClientCertIfGiven":
return tls.VerifyClientCertIfGiven, nil
case "RequireAndVerifyClientCert":
return tls.RequireAndVerifyClientCert, nil
default:
return 0, fmt.Errorf("unknown TLS client_verification value: %s", authType)
}
}

View file

@ -1,10 +1,13 @@
package csconfig package csconfig
import ( import (
"crypto/tls"
"crypto/x509"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
"strings" "strings"
"time"
"github.com/crowdsecurity/crowdsec/pkg/apiclient" "github.com/crowdsecurity/crowdsec/pkg/apiclient"
"github.com/crowdsecurity/crowdsec/pkg/yamlpatch" "github.com/crowdsecurity/crowdsec/pkg/yamlpatch"
@ -19,9 +22,12 @@ type APICfg struct {
} }
type ApiCredentialsCfg struct { type ApiCredentialsCfg struct {
URL string `yaml:"url,omitempty" json:"url,omitempty"` URL string `yaml:"url,omitempty" json:"url,omitempty"`
Login string `yaml:"login,omitempty" json:"login,omitempty"` Login string `yaml:"login,omitempty" json:"login,omitempty"`
Password string `yaml:"password,omitempty" json:"-"` Password string `yaml:"password,omitempty" json:"-"`
CACertPath string `yaml:"ca_cert_path,omitempty"`
KeyPath string `yaml:"key_path,omitempty"`
CertPath string `yaml:"cert_path,omitempty"`
} }
/*global api config (for lapi->oapi)*/ /*global api config (for lapi->oapi)*/
@ -73,11 +79,34 @@ func (l *LocalApiClientCfg) Load() error {
l.Credentials.URL = l.Credentials.URL + "/" l.Credentials.URL = l.Credentials.URL + "/"
} }
} }
if l.Credentials.Login != "" && (l.Credentials.CACertPath != "" || l.Credentials.CertPath != "" || l.Credentials.KeyPath != "") {
return fmt.Errorf("user/password authentication and TLS authentication are mutually exclusive")
}
if l.InsecureSkipVerify == nil { if l.InsecureSkipVerify == nil {
apiclient.InsecureSkipVerify = false apiclient.InsecureSkipVerify = false
} else { } else {
apiclient.InsecureSkipVerify = *l.InsecureSkipVerify apiclient.InsecureSkipVerify = *l.InsecureSkipVerify
} }
if l.Credentials.CACertPath != "" && l.Credentials.CertPath != "" && l.Credentials.KeyPath != "" {
cert, err := tls.LoadX509KeyPair(l.Credentials.CertPath, l.Credentials.KeyPath)
if err != nil {
return errors.Wrapf(err, "failed to load api client certificate")
}
caCert, err := ioutil.ReadFile(l.Credentials.CACertPath)
if err != nil {
return errors.Wrapf(err, "failed to load cacert")
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
apiclient.Cert = &cert
apiclient.CaCertPool = caCertPool
}
return nil return nil
} }
@ -128,8 +157,15 @@ type LocalApiServerCfg struct {
} }
type TLSCfg struct { type TLSCfg struct {
CertFilePath string `yaml:"cert_file"` CertFilePath string `yaml:"cert_file"`
KeyFilePath string `yaml:"key_file"` KeyFilePath string `yaml:"key_file"`
ClientVerification string `yaml:"client_verification,omitempty"`
ServerName string `yaml:"server_name"`
CACertPath string `yaml:"ca_cert_path"`
AllowedAgentsOU []string `yaml:"agents_allowed_ou"`
AllowedBouncersOU []string `yaml:"bouncers_allowed_ou"`
CRLPath string `yaml:"crl_path"`
CacheExpiration *time.Duration `yaml:"cache_expiration,omitempty"`
} }
func (c *Config) LoadAPIServer() error { func (c *Config) LoadAPIServer() error {

View file

@ -2,6 +2,7 @@ package csconfig
import ( import (
"fmt" "fmt"
"time"
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -23,9 +24,20 @@ type DatabaseCfg struct {
MaxOpenConns *int `yaml:"max_open_conns,omitempty"` MaxOpenConns *int `yaml:"max_open_conns,omitempty"`
} }
type AuthGCCfg struct {
Cert *string `yaml:"cert,omitempty"`
CertDuration *time.Duration
Api *string `yaml:"api_key,omitempty"`
ApiDuration *time.Duration
LoginPassword *string `yaml:"login_password,omitempty"`
LoginPasswordDuration *time.Duration
}
type FlushDBCfg struct { type FlushDBCfg struct {
MaxItems *int `yaml:"max_items"` MaxItems *int `yaml:"max_items,omitempty"`
MaxAge *string `yaml:"max_age"` MaxAge *string `yaml:"max_age,omitempty"`
BouncersGC *AuthGCCfg `yaml:"bouncers_autodelete,omitempty"`
AgentsGC *AuthGCCfg `yaml:"agents_autodelete,omitempty"`
} }
func (c *Config) LoadDBConfig() error { func (c *Config) LoadDBConfig() error {

View file

@ -7,10 +7,13 @@ import (
"strings" "strings"
"time" "time"
"github.com/crowdsecurity/crowdsec/pkg/csconfig"
"github.com/crowdsecurity/crowdsec/pkg/database/ent" "github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/alert" "github.com/crowdsecurity/crowdsec/pkg/database/ent/alert"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/bouncer"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/decision" "github.com/crowdsecurity/crowdsec/pkg/database/ent/decision"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/event" "github.com/crowdsecurity/crowdsec/pkg/database/ent/event"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/meta" "github.com/crowdsecurity/crowdsec/pkg/database/ent/meta"
"github.com/crowdsecurity/crowdsec/pkg/models" "github.com/crowdsecurity/crowdsec/pkg/models"
"github.com/crowdsecurity/crowdsec/pkg/types" "github.com/crowdsecurity/crowdsec/pkg/types"
@ -890,6 +893,77 @@ func (c *Client) FlushOrphans() {
} }
} }
func (c *Client) FlushAgentsAndBouncers(agentsCfg *csconfig.AuthGCCfg, bouncersCfg *csconfig.AuthGCCfg) error {
log.Printf("starting FlushAgentsAndBouncers")
if bouncersCfg != nil {
if bouncersCfg.ApiDuration != nil {
log.Printf("trying to delete old bouncers from api")
deletionCount, err := c.Ent.Bouncer.Delete().Where(
bouncer.LastPullLTE(time.Now().UTC().Add(*bouncersCfg.ApiDuration)),
).Where(
bouncer.AuthTypeEQ(types.ApiKeyAuthType),
).Exec(c.CTX)
if err != nil {
c.Log.Errorf("while auto-deleting expired bouncers (api key) : %s", err)
} else if deletionCount > 0 {
c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount)
}
}
if bouncersCfg.CertDuration != nil {
log.Printf("trying to delete old bouncers from cert")
deletionCount, err := c.Ent.Bouncer.Delete().Where(
bouncer.LastPullLTE(time.Now().UTC().Add(*bouncersCfg.CertDuration)),
).Where(
bouncer.AuthTypeEQ(types.TlsAuthType),
).Exec(c.CTX)
if err != nil {
c.Log.Errorf("while auto-deleting expired bouncers (api key) : %s", err)
} else if deletionCount > 0 {
c.Log.Infof("deleted %d expired bouncers (api auth)", deletionCount)
}
}
}
if agentsCfg != nil {
if agentsCfg.CertDuration != nil {
log.Printf("trying to delete old agents from cert")
deletionCount, err := c.Ent.Machine.Delete().Where(
machine.LastPushLTE(time.Now().UTC().Add(*agentsCfg.CertDuration)),
).Where(
machine.Not(machine.HasAlerts()),
).Where(
machine.AuthTypeEQ(types.TlsAuthType),
).Exec(c.CTX)
log.Printf("deleted %d entries", deletionCount)
if err != nil {
c.Log.Errorf("while auto-deleting expired machine (cert) : %s", err)
} else if deletionCount > 0 {
c.Log.Infof("deleted %d expired machine (cert auth)", deletionCount)
}
}
if agentsCfg.LoginPasswordDuration != nil {
log.Printf("trying to delete old agents from password")
deletionCount, err := c.Ent.Machine.Delete().Where(
machine.LastPushLTE(time.Now().UTC().Add(*agentsCfg.LoginPasswordDuration)),
).Where(
machine.Not(machine.HasAlerts()),
).Where(
machine.AuthTypeEQ(types.PasswordAuthType),
).Exec(c.CTX)
log.Printf("deleted %d entries", deletionCount)
if err != nil {
c.Log.Errorf("while auto-deleting expired machine (password) : %s", err)
} else if deletionCount > 0 {
c.Log.Infof("deleted %d expired machine (password auth)", deletionCount)
}
}
}
return nil
}
func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error { func (c *Client) FlushAlerts(MaxAge string, MaxItems int) error {
var deletedByAge int var deletedByAge int
var deletedByNbItem int var deletedByNbItem int

View file

@ -18,6 +18,15 @@ func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) {
return result, nil return result, nil
} }
func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX)
if err != nil {
return &ent.Bouncer{}, errors.Wrapf(QueryFail, "select bouncer: %s", err)
}
return result, nil
}
func (c *Client) ListBouncers() ([]*ent.Bouncer, error) { func (c *Client) ListBouncers() ([]*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().All(c.CTX) result, err := c.Ent.Bouncer.Query().All(c.CTX)
if err != nil { if err != nil {
@ -26,20 +35,21 @@ func (c *Client) ListBouncers() ([]*ent.Bouncer, error) {
return result, nil return result, nil
} }
func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string) error { func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) {
_, err := c.Ent.Bouncer. bouncer, err := c.Ent.Bouncer.
Create(). Create().
SetName(name). SetName(name).
SetAPIKey(apiKey). SetAPIKey(apiKey).
SetRevoked(false). SetRevoked(false).
SetAuthType(authType).
Save(c.CTX) Save(c.CTX)
if err != nil { if err != nil {
if ent.IsConstraintError(err) { if ent.IsConstraintError(err) {
return fmt.Errorf("bouncer %s already exists", name) return nil, fmt.Errorf("bouncer %s already exists", name)
} }
return fmt.Errorf("unable to save api key in database: %s", err) return nil, fmt.Errorf("unable to save api key in database: %s", err)
} }
return nil return bouncer, nil
} }
func (c *Client) DeleteBouncer(name string) error { func (c *Client) DeleteBouncer(name string) error {

View file

@ -122,14 +122,61 @@ func (c *Client) StartFlushScheduler(config *csconfig.FlushDBCfg) (*gocron.Sched
if config.MaxItems != nil { if config.MaxItems != nil {
maxItems = *config.MaxItems maxItems = *config.MaxItems
} }
if config.MaxAge != nil && *config.MaxAge != "" { if config.MaxAge != nil && *config.MaxAge != "" {
maxAge = *config.MaxAge maxAge = *config.MaxAge
} }
// Init & Start cronjob every minute
// Init & Start cronjob every minute for alerts
scheduler := gocron.NewScheduler(time.UTC) scheduler := gocron.NewScheduler(time.UTC)
job, _ := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems) job, err := scheduler.Every(1).Minute().Do(c.FlushAlerts, maxAge, maxItems)
if err != nil {
return nil, errors.Wrap(err, "while starting FlushAlerts scheduler")
}
job.SingletonMode() job.SingletonMode()
// Init & Start cronjob every hour for bouncers/agents
if config.AgentsGC != nil {
if config.AgentsGC.Cert != nil {
duration, err := types.ParseDuration(*config.AgentsGC.Cert)
if err != nil {
return nil, errors.Wrap(err, "while parsing agents cert auto-delete duration")
}
config.AgentsGC.CertDuration = &duration
}
if config.AgentsGC.LoginPassword != nil {
duration, err := types.ParseDuration(*config.AgentsGC.LoginPassword)
if err != nil {
return nil, errors.Wrap(err, "while parsing agents login/password auto-delete duration")
}
config.AgentsGC.LoginPasswordDuration = &duration
}
if config.AgentsGC.Api != nil {
log.Warningf("agents auto-delete for API auth is not supported (use cert or login_password)")
}
}
if config.BouncersGC != nil {
if config.BouncersGC.Cert != nil {
duration, err := types.ParseDuration(*config.BouncersGC.Cert)
if err != nil {
return nil, errors.Wrap(err, "while parsing bouncers cert auto-delete duration")
}
config.BouncersGC.CertDuration = &duration
}
if config.BouncersGC.Api != nil {
duration, err := types.ParseDuration(*config.BouncersGC.Api)
if err != nil {
return nil, errors.Wrap(err, "while parsing bouncers api auto-delete duration")
}
config.BouncersGC.ApiDuration = &duration
}
if config.BouncersGC.LoginPassword != nil {
log.Warningf("bouncers auto-delete for login/password auth is not supported (use cert or api)")
}
}
baJob, err := scheduler.Every(1).Minute().Do(c.FlushAgentsAndBouncers, config.AgentsGC, config.BouncersGC)
if err != nil {
return nil, errors.Wrap(err, "while starting FlushAgentsAndBouncers scheduler")
}
baJob.SingletonMode()
scheduler.StartAsync() scheduler.StartAsync()
return scheduler, nil return scheduler, nil

View file

@ -36,6 +36,8 @@ type Bouncer struct {
Until time.Time `json:"until"` Until time.Time `json:"until"`
// LastPull holds the value of the "last_pull" field. // LastPull holds the value of the "last_pull" field.
LastPull time.Time `json:"last_pull"` LastPull time.Time `json:"last_pull"`
// AuthType holds the value of the "auth_type" field.
AuthType string `json:"auth_type"`
} }
// scanValues returns the types for scanning values from sql.Rows. // scanValues returns the types for scanning values from sql.Rows.
@ -47,7 +49,7 @@ func (*Bouncer) scanValues(columns []string) ([]interface{}, error) {
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case bouncer.FieldID: case bouncer.FieldID:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case bouncer.FieldName, bouncer.FieldAPIKey, bouncer.FieldIPAddress, bouncer.FieldType, bouncer.FieldVersion: case bouncer.FieldName, bouncer.FieldAPIKey, bouncer.FieldIPAddress, bouncer.FieldType, bouncer.FieldVersion, bouncer.FieldAuthType:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case bouncer.FieldCreatedAt, bouncer.FieldUpdatedAt, bouncer.FieldUntil, bouncer.FieldLastPull: case bouncer.FieldCreatedAt, bouncer.FieldUpdatedAt, bouncer.FieldUntil, bouncer.FieldLastPull:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@ -134,6 +136,12 @@ func (b *Bouncer) assignValues(columns []string, values []interface{}) error {
} else if value.Valid { } else if value.Valid {
b.LastPull = value.Time b.LastPull = value.Time
} }
case bouncer.FieldAuthType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field auth_type", values[i])
} else if value.Valid {
b.AuthType = value.String
}
} }
} }
return nil return nil
@ -186,6 +194,8 @@ func (b *Bouncer) String() string {
builder.WriteString(b.Until.Format(time.ANSIC)) builder.WriteString(b.Until.Format(time.ANSIC))
builder.WriteString(", last_pull=") builder.WriteString(", last_pull=")
builder.WriteString(b.LastPull.Format(time.ANSIC)) builder.WriteString(b.LastPull.Format(time.ANSIC))
builder.WriteString(", auth_type=")
builder.WriteString(b.AuthType)
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }

View file

@ -31,6 +31,8 @@ const (
FieldUntil = "until" FieldUntil = "until"
// FieldLastPull holds the string denoting the last_pull field in the database. // FieldLastPull holds the string denoting the last_pull field in the database.
FieldLastPull = "last_pull" FieldLastPull = "last_pull"
// FieldAuthType holds the string denoting the auth_type field in the database.
FieldAuthType = "auth_type"
// Table holds the table name of the bouncer in the database. // Table holds the table name of the bouncer in the database.
Table = "bouncers" Table = "bouncers"
) )
@ -48,6 +50,7 @@ var Columns = []string{
FieldVersion, FieldVersion,
FieldUntil, FieldUntil,
FieldLastPull, FieldLastPull,
FieldAuthType,
} }
// ValidColumn reports if the column name is valid (part of the table columns). // ValidColumn reports if the column name is valid (part of the table columns).
@ -75,4 +78,6 @@ var (
DefaultUntil func() time.Time DefaultUntil func() time.Time
// DefaultLastPull holds the default value on creation for the "last_pull" field. // DefaultLastPull holds the default value on creation for the "last_pull" field.
DefaultLastPull func() time.Time DefaultLastPull func() time.Time
// DefaultAuthType holds the default value on creation for the "auth_type" field.
DefaultAuthType string
) )

View file

@ -162,6 +162,13 @@ func LastPull(v time.Time) predicate.Bouncer {
}) })
} }
// AuthType applies equality check predicate on the "auth_type" field. It's identical to AuthTypeEQ.
func AuthType(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldAuthType), v))
})
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Bouncer { func CreatedAtEQ(v time.Time) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) { return predicate.Bouncer(func(s *sql.Selector) {
@ -1119,6 +1126,117 @@ func LastPullLTE(v time.Time) predicate.Bouncer {
}) })
} }
// AuthTypeEQ applies the EQ predicate on the "auth_type" field.
func AuthTypeEQ(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldAuthType), v))
})
}
// AuthTypeNEQ applies the NEQ predicate on the "auth_type" field.
func AuthTypeNEQ(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldAuthType), v))
})
}
// AuthTypeIn applies the In predicate on the "auth_type" field.
func AuthTypeIn(vs ...string) predicate.Bouncer {
v := make([]interface{}, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Bouncer(func(s *sql.Selector) {
// if not arguments were provided, append the FALSE constants,
// since we can't apply "IN ()". This will make this predicate falsy.
if len(v) == 0 {
s.Where(sql.False())
return
}
s.Where(sql.In(s.C(FieldAuthType), v...))
})
}
// AuthTypeNotIn applies the NotIn predicate on the "auth_type" field.
func AuthTypeNotIn(vs ...string) predicate.Bouncer {
v := make([]interface{}, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Bouncer(func(s *sql.Selector) {
// if not arguments were provided, append the FALSE constants,
// since we can't apply "IN ()". This will make this predicate falsy.
if len(v) == 0 {
s.Where(sql.False())
return
}
s.Where(sql.NotIn(s.C(FieldAuthType), v...))
})
}
// AuthTypeGT applies the GT predicate on the "auth_type" field.
func AuthTypeGT(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldAuthType), v))
})
}
// AuthTypeGTE applies the GTE predicate on the "auth_type" field.
func AuthTypeGTE(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldAuthType), v))
})
}
// AuthTypeLT applies the LT predicate on the "auth_type" field.
func AuthTypeLT(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldAuthType), v))
})
}
// AuthTypeLTE applies the LTE predicate on the "auth_type" field.
func AuthTypeLTE(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldAuthType), v))
})
}
// AuthTypeContains applies the Contains predicate on the "auth_type" field.
func AuthTypeContains(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.Contains(s.C(FieldAuthType), v))
})
}
// AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field.
func AuthTypeHasPrefix(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.HasPrefix(s.C(FieldAuthType), v))
})
}
// AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field.
func AuthTypeHasSuffix(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.HasSuffix(s.C(FieldAuthType), v))
})
}
// AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field.
func AuthTypeEqualFold(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.EqualFold(s.C(FieldAuthType), v))
})
}
// AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field.
func AuthTypeContainsFold(v string) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) {
s.Where(sql.ContainsFold(s.C(FieldAuthType), v))
})
}
// And groups predicates with the AND operator between them. // And groups predicates with the AND operator between them.
func And(predicates ...predicate.Bouncer) predicate.Bouncer { func And(predicates ...predicate.Bouncer) predicate.Bouncer {
return predicate.Bouncer(func(s *sql.Selector) { return predicate.Bouncer(func(s *sql.Selector) {

View file

@ -136,6 +136,20 @@ func (bc *BouncerCreate) SetNillableLastPull(t *time.Time) *BouncerCreate {
return bc return bc
} }
// SetAuthType sets the "auth_type" field.
func (bc *BouncerCreate) SetAuthType(s string) *BouncerCreate {
bc.mutation.SetAuthType(s)
return bc
}
// SetNillableAuthType sets the "auth_type" field if the given value is not nil.
func (bc *BouncerCreate) SetNillableAuthType(s *string) *BouncerCreate {
if s != nil {
bc.SetAuthType(*s)
}
return bc
}
// Mutation returns the BouncerMutation object of the builder. // Mutation returns the BouncerMutation object of the builder.
func (bc *BouncerCreate) Mutation() *BouncerMutation { func (bc *BouncerCreate) Mutation() *BouncerMutation {
return bc.mutation return bc.mutation
@ -227,6 +241,10 @@ func (bc *BouncerCreate) defaults() {
v := bouncer.DefaultLastPull() v := bouncer.DefaultLastPull()
bc.mutation.SetLastPull(v) bc.mutation.SetLastPull(v)
} }
if _, ok := bc.mutation.AuthType(); !ok {
v := bouncer.DefaultAuthType
bc.mutation.SetAuthType(v)
}
} }
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
@ -243,6 +261,9 @@ func (bc *BouncerCreate) check() error {
if _, ok := bc.mutation.LastPull(); !ok { if _, ok := bc.mutation.LastPull(); !ok {
return &ValidationError{Name: "last_pull", err: errors.New(`ent: missing required field "Bouncer.last_pull"`)} return &ValidationError{Name: "last_pull", err: errors.New(`ent: missing required field "Bouncer.last_pull"`)}
} }
if _, ok := bc.mutation.AuthType(); !ok {
return &ValidationError{Name: "auth_type", err: errors.New(`ent: missing required field "Bouncer.auth_type"`)}
}
return nil return nil
} }
@ -350,6 +371,14 @@ func (bc *BouncerCreate) createSpec() (*Bouncer, *sqlgraph.CreateSpec) {
}) })
_node.LastPull = value _node.LastPull = value
} }
if value, ok := bc.mutation.AuthType(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldAuthType,
})
_node.AuthType = value
}
return _node, _spec return _node, _spec
} }

View file

@ -164,6 +164,20 @@ func (bu *BouncerUpdate) SetNillableLastPull(t *time.Time) *BouncerUpdate {
return bu return bu
} }
// SetAuthType sets the "auth_type" field.
func (bu *BouncerUpdate) SetAuthType(s string) *BouncerUpdate {
bu.mutation.SetAuthType(s)
return bu
}
// SetNillableAuthType sets the "auth_type" field if the given value is not nil.
func (bu *BouncerUpdate) SetNillableAuthType(s *string) *BouncerUpdate {
if s != nil {
bu.SetAuthType(*s)
}
return bu
}
// Mutation returns the BouncerMutation object of the builder. // Mutation returns the BouncerMutation object of the builder.
func (bu *BouncerUpdate) Mutation() *BouncerMutation { func (bu *BouncerUpdate) Mutation() *BouncerMutation {
return bu.mutation return bu.mutation
@ -360,6 +374,13 @@ func (bu *BouncerUpdate) sqlSave(ctx context.Context) (n int, err error) {
Column: bouncer.FieldLastPull, Column: bouncer.FieldLastPull,
}) })
} }
if value, ok := bu.mutation.AuthType(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldAuthType,
})
}
if n, err = sqlgraph.UpdateNodes(ctx, bu.driver, _spec); err != nil { if n, err = sqlgraph.UpdateNodes(ctx, bu.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok { if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{bouncer.Label} err = &NotFoundError{bouncer.Label}
@ -515,6 +536,20 @@ func (buo *BouncerUpdateOne) SetNillableLastPull(t *time.Time) *BouncerUpdateOne
return buo return buo
} }
// SetAuthType sets the "auth_type" field.
func (buo *BouncerUpdateOne) SetAuthType(s string) *BouncerUpdateOne {
buo.mutation.SetAuthType(s)
return buo
}
// SetNillableAuthType sets the "auth_type" field if the given value is not nil.
func (buo *BouncerUpdateOne) SetNillableAuthType(s *string) *BouncerUpdateOne {
if s != nil {
buo.SetAuthType(*s)
}
return buo
}
// Mutation returns the BouncerMutation object of the builder. // Mutation returns the BouncerMutation object of the builder.
func (buo *BouncerUpdateOne) Mutation() *BouncerMutation { func (buo *BouncerUpdateOne) Mutation() *BouncerMutation {
return buo.mutation return buo.mutation
@ -735,6 +770,13 @@ func (buo *BouncerUpdateOne) sqlSave(ctx context.Context) (_node *Bouncer, err e
Column: bouncer.FieldLastPull, Column: bouncer.FieldLastPull,
}) })
} }
if value, ok := buo.mutation.AuthType(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: bouncer.FieldAuthType,
})
}
_node = &Bouncer{config: buo.config} _node = &Bouncer{config: buo.config}
_spec.Assign = _node.assignValues _spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues _spec.ScanValues = _node.scanValues

View file

@ -38,6 +38,8 @@ type Machine struct {
IsValidated bool `json:"isValidated,omitempty"` IsValidated bool `json:"isValidated,omitempty"`
// Status holds the value of the "status" field. // Status holds the value of the "status" field.
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
// AuthType holds the value of the "auth_type" field.
AuthType string `json:"auth_type"`
// Edges holds the relations/edges for other nodes in the graph. // Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the MachineQuery when eager-loading is set. // The values are being populated by the MachineQuery when eager-loading is set.
Edges MachineEdges `json:"edges"` Edges MachineEdges `json:"edges"`
@ -70,7 +72,7 @@ func (*Machine) scanValues(columns []string) ([]interface{}, error) {
values[i] = new(sql.NullBool) values[i] = new(sql.NullBool)
case machine.FieldID: case machine.FieldID:
values[i] = new(sql.NullInt64) values[i] = new(sql.NullInt64)
case machine.FieldMachineId, machine.FieldPassword, machine.FieldIpAddress, machine.FieldScenarios, machine.FieldVersion, machine.FieldStatus: case machine.FieldMachineId, machine.FieldPassword, machine.FieldIpAddress, machine.FieldScenarios, machine.FieldVersion, machine.FieldStatus, machine.FieldAuthType:
values[i] = new(sql.NullString) values[i] = new(sql.NullString)
case machine.FieldCreatedAt, machine.FieldUpdatedAt, machine.FieldLastPush, machine.FieldLastHeartbeat: case machine.FieldCreatedAt, machine.FieldUpdatedAt, machine.FieldLastPush, machine.FieldLastHeartbeat:
values[i] = new(sql.NullTime) values[i] = new(sql.NullTime)
@ -165,6 +167,12 @@ func (m *Machine) assignValues(columns []string, values []interface{}) error {
} else if value.Valid { } else if value.Valid {
m.Status = value.String m.Status = value.String
} }
case machine.FieldAuthType:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field auth_type", values[i])
} else if value.Valid {
m.AuthType = value.String
}
} }
} }
return nil return nil
@ -227,6 +235,8 @@ func (m *Machine) String() string {
builder.WriteString(fmt.Sprintf("%v", m.IsValidated)) builder.WriteString(fmt.Sprintf("%v", m.IsValidated))
builder.WriteString(", status=") builder.WriteString(", status=")
builder.WriteString(m.Status) builder.WriteString(m.Status)
builder.WriteString(", auth_type=")
builder.WriteString(m.AuthType)
builder.WriteByte(')') builder.WriteByte(')')
return builder.String() return builder.String()
} }

View file

@ -33,6 +33,8 @@ const (
FieldIsValidated = "is_validated" FieldIsValidated = "is_validated"
// FieldStatus holds the string denoting the status field in the database. // FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status" FieldStatus = "status"
// FieldAuthType holds the string denoting the auth_type field in the database.
FieldAuthType = "auth_type"
// EdgeAlerts holds the string denoting the alerts edge name in mutations. // EdgeAlerts holds the string denoting the alerts edge name in mutations.
EdgeAlerts = "alerts" EdgeAlerts = "alerts"
// Table holds the table name of the machine in the database. // Table holds the table name of the machine in the database.
@ -60,6 +62,7 @@ var Columns = []string{
FieldVersion, FieldVersion,
FieldIsValidated, FieldIsValidated,
FieldStatus, FieldStatus,
FieldAuthType,
} }
// ValidColumn reports if the column name is valid (part of the table columns). // ValidColumn reports if the column name is valid (part of the table columns).
@ -93,4 +96,6 @@ var (
ScenariosValidator func(string) error ScenariosValidator func(string) error
// DefaultIsValidated holds the default value on creation for the "isValidated" field. // DefaultIsValidated holds the default value on creation for the "isValidated" field.
DefaultIsValidated bool DefaultIsValidated bool
// DefaultAuthType holds the default value on creation for the "auth_type" field.
DefaultAuthType string
) )

View file

@ -170,6 +170,13 @@ func Status(v string) predicate.Machine {
}) })
} }
// AuthType applies equality check predicate on the "auth_type" field. It's identical to AuthTypeEQ.
func AuthType(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldAuthType), v))
})
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field. // CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Machine { func CreatedAtEQ(v time.Time) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) { return predicate.Machine(func(s *sql.Selector) {
@ -1252,6 +1259,117 @@ func StatusContainsFold(v string) predicate.Machine {
}) })
} }
// AuthTypeEQ applies the EQ predicate on the "auth_type" field.
func AuthTypeEQ(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.EQ(s.C(FieldAuthType), v))
})
}
// AuthTypeNEQ applies the NEQ predicate on the "auth_type" field.
func AuthTypeNEQ(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.NEQ(s.C(FieldAuthType), v))
})
}
// AuthTypeIn applies the In predicate on the "auth_type" field.
func AuthTypeIn(vs ...string) predicate.Machine {
v := make([]interface{}, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Machine(func(s *sql.Selector) {
// if not arguments were provided, append the FALSE constants,
// since we can't apply "IN ()". This will make this predicate falsy.
if len(v) == 0 {
s.Where(sql.False())
return
}
s.Where(sql.In(s.C(FieldAuthType), v...))
})
}
// AuthTypeNotIn applies the NotIn predicate on the "auth_type" field.
func AuthTypeNotIn(vs ...string) predicate.Machine {
v := make([]interface{}, len(vs))
for i := range v {
v[i] = vs[i]
}
return predicate.Machine(func(s *sql.Selector) {
// if not arguments were provided, append the FALSE constants,
// since we can't apply "IN ()". This will make this predicate falsy.
if len(v) == 0 {
s.Where(sql.False())
return
}
s.Where(sql.NotIn(s.C(FieldAuthType), v...))
})
}
// AuthTypeGT applies the GT predicate on the "auth_type" field.
func AuthTypeGT(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.GT(s.C(FieldAuthType), v))
})
}
// AuthTypeGTE applies the GTE predicate on the "auth_type" field.
func AuthTypeGTE(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.GTE(s.C(FieldAuthType), v))
})
}
// AuthTypeLT applies the LT predicate on the "auth_type" field.
func AuthTypeLT(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.LT(s.C(FieldAuthType), v))
})
}
// AuthTypeLTE applies the LTE predicate on the "auth_type" field.
func AuthTypeLTE(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.LTE(s.C(FieldAuthType), v))
})
}
// AuthTypeContains applies the Contains predicate on the "auth_type" field.
func AuthTypeContains(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.Contains(s.C(FieldAuthType), v))
})
}
// AuthTypeHasPrefix applies the HasPrefix predicate on the "auth_type" field.
func AuthTypeHasPrefix(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.HasPrefix(s.C(FieldAuthType), v))
})
}
// AuthTypeHasSuffix applies the HasSuffix predicate on the "auth_type" field.
func AuthTypeHasSuffix(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.HasSuffix(s.C(FieldAuthType), v))
})
}
// AuthTypeEqualFold applies the EqualFold predicate on the "auth_type" field.
func AuthTypeEqualFold(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.EqualFold(s.C(FieldAuthType), v))
})
}
// AuthTypeContainsFold applies the ContainsFold predicate on the "auth_type" field.
func AuthTypeContainsFold(v string) predicate.Machine {
return predicate.Machine(func(s *sql.Selector) {
s.Where(sql.ContainsFold(s.C(FieldAuthType), v))
})
}
// HasAlerts applies the HasEdge predicate on the "alerts" edge. // HasAlerts applies the HasEdge predicate on the "alerts" edge.
func HasAlerts() predicate.Machine { func HasAlerts() predicate.Machine {
return predicate.Machine(func(s *sql.Selector) { return predicate.Machine(func(s *sql.Selector) {

View file

@ -151,6 +151,20 @@ func (mc *MachineCreate) SetNillableStatus(s *string) *MachineCreate {
return mc return mc
} }
// SetAuthType sets the "auth_type" field.
func (mc *MachineCreate) SetAuthType(s string) *MachineCreate {
mc.mutation.SetAuthType(s)
return mc
}
// SetNillableAuthType sets the "auth_type" field if the given value is not nil.
func (mc *MachineCreate) SetNillableAuthType(s *string) *MachineCreate {
if s != nil {
mc.SetAuthType(*s)
}
return mc
}
// AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs.
func (mc *MachineCreate) AddAlertIDs(ids ...int) *MachineCreate { func (mc *MachineCreate) AddAlertIDs(ids ...int) *MachineCreate {
mc.mutation.AddAlertIDs(ids...) mc.mutation.AddAlertIDs(ids...)
@ -257,6 +271,10 @@ func (mc *MachineCreate) defaults() {
v := machine.DefaultIsValidated v := machine.DefaultIsValidated
mc.mutation.SetIsValidated(v) mc.mutation.SetIsValidated(v)
} }
if _, ok := mc.mutation.AuthType(); !ok {
v := machine.DefaultAuthType
mc.mutation.SetAuthType(v)
}
} }
// check runs all checks and user-defined validators on the builder. // check runs all checks and user-defined validators on the builder.
@ -278,6 +296,9 @@ func (mc *MachineCreate) check() error {
if _, ok := mc.mutation.IsValidated(); !ok { if _, ok := mc.mutation.IsValidated(); !ok {
return &ValidationError{Name: "isValidated", err: errors.New(`ent: missing required field "Machine.isValidated"`)} return &ValidationError{Name: "isValidated", err: errors.New(`ent: missing required field "Machine.isValidated"`)}
} }
if _, ok := mc.mutation.AuthType(); !ok {
return &ValidationError{Name: "auth_type", err: errors.New(`ent: missing required field "Machine.auth_type"`)}
}
return nil return nil
} }
@ -393,6 +414,14 @@ func (mc *MachineCreate) createSpec() (*Machine, *sqlgraph.CreateSpec) {
}) })
_node.Status = value _node.Status = value
} }
if value, ok := mc.mutation.AuthType(); ok {
_spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: machine.FieldAuthType,
})
_node.AuthType = value
}
if nodes := mc.mutation.AlertsIDs(); len(nodes) > 0 { if nodes := mc.mutation.AlertsIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,

View file

@ -169,6 +169,20 @@ func (mu *MachineUpdate) ClearStatus() *MachineUpdate {
return mu return mu
} }
// SetAuthType sets the "auth_type" field.
func (mu *MachineUpdate) SetAuthType(s string) *MachineUpdate {
mu.mutation.SetAuthType(s)
return mu
}
// SetNillableAuthType sets the "auth_type" field if the given value is not nil.
func (mu *MachineUpdate) SetNillableAuthType(s *string) *MachineUpdate {
if s != nil {
mu.SetAuthType(*s)
}
return mu
}
// AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs.
func (mu *MachineUpdate) AddAlertIDs(ids ...int) *MachineUpdate { func (mu *MachineUpdate) AddAlertIDs(ids ...int) *MachineUpdate {
mu.mutation.AddAlertIDs(ids...) mu.mutation.AddAlertIDs(ids...)
@ -438,6 +452,13 @@ func (mu *MachineUpdate) sqlSave(ctx context.Context) (n int, err error) {
Column: machine.FieldStatus, Column: machine.FieldStatus,
}) })
} }
if value, ok := mu.mutation.AuthType(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: machine.FieldAuthType,
})
}
if mu.mutation.AlertsCleared() { if mu.mutation.AlertsCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,
@ -651,6 +672,20 @@ func (muo *MachineUpdateOne) ClearStatus() *MachineUpdateOne {
return muo return muo
} }
// SetAuthType sets the "auth_type" field.
func (muo *MachineUpdateOne) SetAuthType(s string) *MachineUpdateOne {
muo.mutation.SetAuthType(s)
return muo
}
// SetNillableAuthType sets the "auth_type" field if the given value is not nil.
func (muo *MachineUpdateOne) SetNillableAuthType(s *string) *MachineUpdateOne {
if s != nil {
muo.SetAuthType(*s)
}
return muo
}
// AddAlertIDs adds the "alerts" edge to the Alert entity by IDs. // AddAlertIDs adds the "alerts" edge to the Alert entity by IDs.
func (muo *MachineUpdateOne) AddAlertIDs(ids ...int) *MachineUpdateOne { func (muo *MachineUpdateOne) AddAlertIDs(ids ...int) *MachineUpdateOne {
muo.mutation.AddAlertIDs(ids...) muo.mutation.AddAlertIDs(ids...)
@ -944,6 +979,13 @@ func (muo *MachineUpdateOne) sqlSave(ctx context.Context) (_node *Machine, err e
Column: machine.FieldStatus, Column: machine.FieldStatus,
}) })
} }
if value, ok := muo.mutation.AuthType(); ok {
_spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{
Type: field.TypeString,
Value: value,
Column: machine.FieldAuthType,
})
}
if muo.mutation.AlertsCleared() { if muo.mutation.AlertsCleared() {
edge := &sqlgraph.EdgeSpec{ edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M, Rel: sqlgraph.O2M,

View file

@ -69,6 +69,7 @@ var (
{Name: "version", Type: field.TypeString, Nullable: true}, {Name: "version", Type: field.TypeString, Nullable: true},
{Name: "until", Type: field.TypeTime, Nullable: true}, {Name: "until", Type: field.TypeTime, Nullable: true},
{Name: "last_pull", Type: field.TypeTime}, {Name: "last_pull", Type: field.TypeTime},
{Name: "auth_type", Type: field.TypeString, Default: "api-key"},
} }
// BouncersTable holds the schema information for the "bouncers" table. // BouncersTable holds the schema information for the "bouncers" table.
BouncersTable = &schema.Table{ BouncersTable = &schema.Table{
@ -163,6 +164,7 @@ var (
{Name: "version", Type: field.TypeString, Nullable: true}, {Name: "version", Type: field.TypeString, Nullable: true},
{Name: "is_validated", Type: field.TypeBool, Default: false}, {Name: "is_validated", Type: field.TypeBool, Default: false},
{Name: "status", Type: field.TypeString, Nullable: true}, {Name: "status", Type: field.TypeString, Nullable: true},
{Name: "auth_type", Type: field.TypeString, Default: "password"},
} }
// MachinesTable holds the schema information for the "machines" table. // MachinesTable holds the schema information for the "machines" table.
MachinesTable = &schema.Table{ MachinesTable = &schema.Table{

View file

@ -2338,6 +2338,7 @@ type BouncerMutation struct {
version *string version *string
until *time.Time until *time.Time
last_pull *time.Time last_pull *time.Time
auth_type *string
clearedFields map[string]struct{} clearedFields map[string]struct{}
done bool done bool
oldValue func(context.Context) (*Bouncer, error) oldValue func(context.Context) (*Bouncer, error)
@ -2880,6 +2881,42 @@ func (m *BouncerMutation) ResetLastPull() {
m.last_pull = nil m.last_pull = nil
} }
// SetAuthType sets the "auth_type" field.
func (m *BouncerMutation) SetAuthType(s string) {
m.auth_type = &s
}
// AuthType returns the value of the "auth_type" field in the mutation.
func (m *BouncerMutation) AuthType() (r string, exists bool) {
v := m.auth_type
if v == nil {
return
}
return *v, true
}
// OldAuthType returns the old "auth_type" field's value of the Bouncer entity.
// If the Bouncer object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *BouncerMutation) OldAuthType(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAuthType is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAuthType requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAuthType: %w", err)
}
return oldValue.AuthType, nil
}
// ResetAuthType resets all changes to the "auth_type" field.
func (m *BouncerMutation) ResetAuthType() {
m.auth_type = nil
}
// Where appends a list predicates to the BouncerMutation builder. // Where appends a list predicates to the BouncerMutation builder.
func (m *BouncerMutation) Where(ps ...predicate.Bouncer) { func (m *BouncerMutation) Where(ps ...predicate.Bouncer) {
m.predicates = append(m.predicates, ps...) m.predicates = append(m.predicates, ps...)
@ -2899,7 +2936,7 @@ func (m *BouncerMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *BouncerMutation) Fields() []string { func (m *BouncerMutation) Fields() []string {
fields := make([]string, 0, 10) fields := make([]string, 0, 11)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, bouncer.FieldCreatedAt) fields = append(fields, bouncer.FieldCreatedAt)
} }
@ -2930,6 +2967,9 @@ func (m *BouncerMutation) Fields() []string {
if m.last_pull != nil { if m.last_pull != nil {
fields = append(fields, bouncer.FieldLastPull) fields = append(fields, bouncer.FieldLastPull)
} }
if m.auth_type != nil {
fields = append(fields, bouncer.FieldAuthType)
}
return fields return fields
} }
@ -2958,6 +2998,8 @@ func (m *BouncerMutation) Field(name string) (ent.Value, bool) {
return m.Until() return m.Until()
case bouncer.FieldLastPull: case bouncer.FieldLastPull:
return m.LastPull() return m.LastPull()
case bouncer.FieldAuthType:
return m.AuthType()
} }
return nil, false return nil, false
} }
@ -2987,6 +3029,8 @@ func (m *BouncerMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldUntil(ctx) return m.OldUntil(ctx)
case bouncer.FieldLastPull: case bouncer.FieldLastPull:
return m.OldLastPull(ctx) return m.OldLastPull(ctx)
case bouncer.FieldAuthType:
return m.OldAuthType(ctx)
} }
return nil, fmt.Errorf("unknown Bouncer field %s", name) return nil, fmt.Errorf("unknown Bouncer field %s", name)
} }
@ -3066,6 +3110,13 @@ func (m *BouncerMutation) SetField(name string, value ent.Value) error {
} }
m.SetLastPull(v) m.SetLastPull(v)
return nil return nil
case bouncer.FieldAuthType:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAuthType(v)
return nil
} }
return fmt.Errorf("unknown Bouncer field %s", name) return fmt.Errorf("unknown Bouncer field %s", name)
} }
@ -3184,6 +3235,9 @@ func (m *BouncerMutation) ResetField(name string) error {
case bouncer.FieldLastPull: case bouncer.FieldLastPull:
m.ResetLastPull() m.ResetLastPull()
return nil return nil
case bouncer.FieldAuthType:
m.ResetAuthType()
return nil
} }
return fmt.Errorf("unknown Bouncer field %s", name) return fmt.Errorf("unknown Bouncer field %s", name)
} }
@ -5246,6 +5300,7 @@ type MachineMutation struct {
version *string version *string
isValidated *bool isValidated *bool
status *string status *string
auth_type *string
clearedFields map[string]struct{} clearedFields map[string]struct{}
alerts map[int]struct{} alerts map[int]struct{}
removedalerts map[int]struct{} removedalerts map[int]struct{}
@ -5840,6 +5895,42 @@ func (m *MachineMutation) ResetStatus() {
delete(m.clearedFields, machine.FieldStatus) delete(m.clearedFields, machine.FieldStatus)
} }
// SetAuthType sets the "auth_type" field.
func (m *MachineMutation) SetAuthType(s string) {
m.auth_type = &s
}
// AuthType returns the value of the "auth_type" field in the mutation.
func (m *MachineMutation) AuthType() (r string, exists bool) {
v := m.auth_type
if v == nil {
return
}
return *v, true
}
// OldAuthType returns the old "auth_type" field's value of the Machine entity.
// If the Machine object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *MachineMutation) OldAuthType(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldAuthType is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldAuthType requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldAuthType: %w", err)
}
return oldValue.AuthType, nil
}
// ResetAuthType resets all changes to the "auth_type" field.
func (m *MachineMutation) ResetAuthType() {
m.auth_type = nil
}
// AddAlertIDs adds the "alerts" edge to the Alert entity by ids. // AddAlertIDs adds the "alerts" edge to the Alert entity by ids.
func (m *MachineMutation) AddAlertIDs(ids ...int) { func (m *MachineMutation) AddAlertIDs(ids ...int) {
if m.alerts == nil { if m.alerts == nil {
@ -5913,7 +6004,7 @@ func (m *MachineMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call // order to get all numeric fields that were incremented/decremented, call
// AddedFields(). // AddedFields().
func (m *MachineMutation) Fields() []string { func (m *MachineMutation) Fields() []string {
fields := make([]string, 0, 11) fields := make([]string, 0, 12)
if m.created_at != nil { if m.created_at != nil {
fields = append(fields, machine.FieldCreatedAt) fields = append(fields, machine.FieldCreatedAt)
} }
@ -5947,6 +6038,9 @@ func (m *MachineMutation) Fields() []string {
if m.status != nil { if m.status != nil {
fields = append(fields, machine.FieldStatus) fields = append(fields, machine.FieldStatus)
} }
if m.auth_type != nil {
fields = append(fields, machine.FieldAuthType)
}
return fields return fields
} }
@ -5977,6 +6071,8 @@ func (m *MachineMutation) Field(name string) (ent.Value, bool) {
return m.IsValidated() return m.IsValidated()
case machine.FieldStatus: case machine.FieldStatus:
return m.Status() return m.Status()
case machine.FieldAuthType:
return m.AuthType()
} }
return nil, false return nil, false
} }
@ -6008,6 +6104,8 @@ func (m *MachineMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldIsValidated(ctx) return m.OldIsValidated(ctx)
case machine.FieldStatus: case machine.FieldStatus:
return m.OldStatus(ctx) return m.OldStatus(ctx)
case machine.FieldAuthType:
return m.OldAuthType(ctx)
} }
return nil, fmt.Errorf("unknown Machine field %s", name) return nil, fmt.Errorf("unknown Machine field %s", name)
} }
@ -6094,6 +6192,13 @@ func (m *MachineMutation) SetField(name string, value ent.Value) error {
} }
m.SetStatus(v) m.SetStatus(v)
return nil return nil
case machine.FieldAuthType:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetAuthType(v)
return nil
} }
return fmt.Errorf("unknown Machine field %s", name) return fmt.Errorf("unknown Machine field %s", name)
} }
@ -6221,6 +6326,9 @@ func (m *MachineMutation) ResetField(name string) error {
case machine.FieldStatus: case machine.FieldStatus:
m.ResetStatus() m.ResetStatus()
return nil return nil
case machine.FieldAuthType:
m.ResetAuthType()
return nil
} }
return fmt.Errorf("unknown Machine field %s", name) return fmt.Errorf("unknown Machine field %s", name)
} }

View file

@ -82,6 +82,10 @@ func init() {
bouncerDescLastPull := bouncerFields[9].Descriptor() bouncerDescLastPull := bouncerFields[9].Descriptor()
// bouncer.DefaultLastPull holds the default value on creation for the last_pull field. // bouncer.DefaultLastPull holds the default value on creation for the last_pull field.
bouncer.DefaultLastPull = bouncerDescLastPull.Default.(func() time.Time) bouncer.DefaultLastPull = bouncerDescLastPull.Default.(func() time.Time)
// bouncerDescAuthType is the schema descriptor for auth_type field.
bouncerDescAuthType := bouncerFields[10].Descriptor()
// bouncer.DefaultAuthType holds the default value on creation for the auth_type field.
bouncer.DefaultAuthType = bouncerDescAuthType.Default.(string)
decisionFields := schema.Decision{}.Fields() decisionFields := schema.Decision{}.Fields()
_ = decisionFields _ = decisionFields
// decisionDescCreatedAt is the schema descriptor for created_at field. // decisionDescCreatedAt is the schema descriptor for created_at field.
@ -152,6 +156,10 @@ func init() {
machineDescIsValidated := machineFields[9].Descriptor() machineDescIsValidated := machineFields[9].Descriptor()
// machine.DefaultIsValidated holds the default value on creation for the isValidated field. // machine.DefaultIsValidated holds the default value on creation for the isValidated field.
machine.DefaultIsValidated = machineDescIsValidated.Default.(bool) machine.DefaultIsValidated = machineDescIsValidated.Default.(bool)
// machineDescAuthType is the schema descriptor for auth_type field.
machineDescAuthType := machineFields[11].Descriptor()
// machine.DefaultAuthType holds the default value on creation for the auth_type field.
machine.DefaultAuthType = machineDescAuthType.Default.(string)
metaFields := schema.Meta{}.Fields() metaFields := schema.Meta{}.Fields()
_ = metaFields _ = metaFields
// metaDescCreatedAt is the schema descriptor for created_at field. // metaDescCreatedAt is the schema descriptor for created_at field.

View file

@ -29,6 +29,7 @@ func (Bouncer) Fields() []ent.Field {
field.Time("until").Default(types.UtcNow).Optional().StructTag(`json:"until"`), field.Time("until").Default(types.UtcNow).Optional().StructTag(`json:"until"`),
field.Time("last_pull"). field.Time("last_pull").
Default(types.UtcNow).StructTag(`json:"last_pull"`), Default(types.UtcNow).StructTag(`json:"last_pull"`),
field.String("auth_type").StructTag(`json:"auth_type"`).Default(types.ApiKeyAuthType),
} }
} }

View file

@ -35,6 +35,7 @@ func (Machine) Fields() []ent.Field {
field.Bool("isValidated"). field.Bool("isValidated").
Default(false), Default(false),
field.String("status").Optional(), field.String("status").Optional(),
field.String("auth_type").Default(types.PasswordAuthType).StructTag(`json:"auth_type"`),
} }
} }

View file

@ -14,11 +14,11 @@ import (
const CapiMachineID = "CAPI" const CapiMachineID = "CAPI"
func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool) (int, error) { func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipAddress string, isValidated bool, force bool, authType string) (*ent.Machine, error) {
hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost) hashPassword, err := bcrypt.GenerateFromPassword([]byte(*password), bcrypt.DefaultCost)
if err != nil { if err != nil {
c.Log.Warningf("CreateMachine : %s", err) c.Log.Warningf("CreateMachine : %s", err)
return 0, errors.Wrap(HashError, "") return nil, errors.Wrap(HashError, "")
} }
machineExist, err := c.Ent.Machine. machineExist, err := c.Ent.Machine.
@ -26,34 +26,39 @@ func (c *Client) CreateMachine(machineID *string, password *strfmt.Password, ipA
Where(machine.MachineIdEQ(*machineID)). Where(machine.MachineIdEQ(*machineID)).
Select(machine.FieldMachineId).Strings(c.CTX) Select(machine.FieldMachineId).Strings(c.CTX)
if err != nil { if err != nil {
return 0, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err) return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err)
} }
if len(machineExist) > 0 { if len(machineExist) > 0 {
if force { if force {
_, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX) _, err := c.Ent.Machine.Update().Where(machine.MachineIdEQ(*machineID)).SetPassword(string(hashPassword)).Save(c.CTX)
if err != nil { if err != nil {
c.Log.Warningf("CreateMachine : %s", err) c.Log.Warningf("CreateMachine : %s", err)
return 0, errors.Wrapf(UpdateFail, "machine '%s'", *machineID) return nil, errors.Wrapf(UpdateFail, "machine '%s'", *machineID)
} }
return 1, nil machine, err := c.QueryMachineByID(*machineID)
if err != nil {
return nil, errors.Wrapf(QueryFail, "machine '%s': %s", *machineID, err)
}
return machine, nil
} }
return 0, errors.Wrapf(UserExists, "user '%s'", *machineID) return nil, errors.Wrapf(UserExists, "user '%s'", *machineID)
} }
_, err = c.Ent.Machine. machine, err := c.Ent.Machine.
Create(). Create().
SetMachineId(*machineID). SetMachineId(*machineID).
SetPassword(string(hashPassword)). SetPassword(string(hashPassword)).
SetIpAddress(ipAddress). SetIpAddress(ipAddress).
SetIsValidated(isValidated). SetIsValidated(isValidated).
SetAuthType(authType).
Save(c.CTX) Save(c.CTX)
if err != nil { if err != nil {
c.Log.Warningf("CreateMachine : %s", err) c.Log.Warningf("CreateMachine : %s", err)
return 0, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID) return nil, errors.Wrapf(InsertFail, "creating machine '%s'", *machineID)
} }
return 1, nil return machine, nil
} }
func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) { func (c *Client) QueryMachineByID(machineID string) (*ent.Machine, error) {

5
pkg/types/constants.go Normal file
View file

@ -0,0 +1,5 @@
package types
const ApiKeyAuthType = "api-key"
const TlsAuthType = "tls"
const PasswordAuthType = "password"

View file

@ -37,5 +37,5 @@ declare stderr
@test "${FILE} CS_LAPI_SECRET not strong enough" { @test "${FILE} CS_LAPI_SECRET not strong enough" {
CS_LAPI_SECRET=foo run -1 --separate-stderr timeout 2s "${CROWDSEC}" CS_LAPI_SECRET=foo run -1 --separate-stderr timeout 2s "${CROWDSEC}"
run -0 echo "${stderr}" run -0 echo "${stderr}"
assert_output --partial "api server init: unable to run local API: CS_LAPI_SECRET not strong enough" assert_output --partial "api server init: unable to run local API: controller init: CS_LAPI_SECRET not strong enough"
} }

View file

@ -0,0 +1,97 @@
#!/usr/bin/env bats
# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si:
set -u
config_disable_agent() {
yq 'del(.crowdsec_service)' -i "${CONFIG_YAML}"
}
setup_file() {
load "../lib/setup_file.sh"
./instance-data load
tmpdir=$(mktemp -d)
export tmpdir
#gen the CA
cfssl gencert --initca ./cfssl/ca.json 2>/dev/null | cfssljson --bare "${tmpdir}/ca"
#gen an intermediate
cfssl gencert --initca ./cfssl/intermediate.json 2>/dev/null | cfssljson --bare "${tmpdir}/inter"
cfssl sign -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config ./cfssl/profiles.json -profile intermediate_ca "${tmpdir}/inter.csr" 2>/dev/null | cfssljson --bare "${tmpdir}/inter"
#gen server cert for crowdsec with the intermediate
cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=server ./cfssl/server.json 2>/dev/null | cfssljson --bare "${tmpdir}/server"
#gen client cert for the bouncer
cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/bouncer.json 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer"
#gen client cert for the bouncer with an invalid OU
cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/bouncer_invalid.json 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_bad_ou"
#gen client cert for the bouncer directly signed by the CA, it should be refused by crowdsec as uses the intermediate
cfssl gencert -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/bouncer.json 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_invalid"
cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/bouncer.json 2>/dev/null | cfssljson --bare "${tmpdir}/bouncer_revoked"
serial="$(openssl x509 -noout -serial -in ${tmpdir}/bouncer_revoked.pem | cut -d '=' -f2)"
echo "ibase=16; $serial" | bc > "${tmpdir}/serials.txt"
cfssl gencrl "${tmpdir}/serials.txt" "${tmpdir}/ca.pem" "${tmpdir}/ca-key.pem" | base64 -d | openssl crl -inform DER -out "${tmpdir}/crl.pem"
yq '
.api.server.tls.cert_file=strenv(tmpdir) + "/server.pem" |
.api.server.tls.key_file=strenv(tmpdir) + "/server-key.pem" |
.api.server.tls.ca_cert_path=strenv(tmpdir) + "/inter.pem" |
.api.server.tls.crl_path=strenv(tmpdir) + "/crl.pem" |
.api.server.tls.bouncers_allowed_ou=["bouncer-ou"]
' -i "${CONFIG_YAML}"
config_disable_agent
}
teardown_file() {
load "../lib/teardown_file.sh"
rm -rf $tmpdir
}
setup() {
load "../lib/setup.sh"
./instance-crowdsec start
}
teardown() {
./instance-crowdsec stop
}
#----------
@test "$FILE there are 0 bouncers" {
run -0 cscli bouncers list -o json
assert_output "[]"
}
@test "$FILE simulate one bouncer request with a valid cert" {
run -0 curl -s --cert "${tmpdir}/bouncer.pem" --key "${tmpdir}/bouncer-key.pem" --cacert "${tmpdir}/inter.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42
assert_output "null"
run -0 cscli bouncers list -o json
run -0 jq '. | length' <(output)
assert_output '1'
run -0 cscli bouncers list -o json
run -0 jq -r '.[] | .name' <(output)
assert_output "localhost@127.0.0.1"
run cscli bouncers delete localhost@127.0.0.1
}
@test "$FILE simulate one bouncer request with an invalid cert" {
run curl -s --cert "${tmpdir}/bouncer_invalid.pem" --key "${tmpdir}/bouncer_invalid-key.pem" --cacert "${tmpdir}/ca-key.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42
run -0 cscli bouncers list -o json
assert_output "[]"
}
@test "$FILE simulate one bouncer request with an invalid OU" {
run curl -s --cert "${tmpdir}/bouncer_bad_ou.pem" --key "${tmpdir}/bouncer_bad_ou-key.pem" --cacert "${tmpdir}/inter.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42
run -0 cscli bouncers list -o json
assert_output "[]"
}
@test "$FILE simulate one bouncer request with a revoked certificate" {
run -0 curl -i -s --cert "${tmpdir}/bouncer_revoked.pem" --key "${tmpdir}/bouncer_revoked-key.pem" --cacert "${tmpdir}/inter.pem" https://localhost:8080/v1/decisions\?ip=42.42.42.42
assert_output --partial "access forbidden"
run -0 cscli bouncers list -o json
assert_output "[]"
}

View file

@ -0,0 +1,136 @@
#!/usr/bin/env bats
# vim: ft=bats:list:ts=8:sts=4:sw=4:et:ai:si:
set -u
setup_file() {
load "../lib/setup_file.sh"
./instance-data load
tmpdir=$(mktemp -d)
export tmpdir
#gen the CA
cfssl gencert --initca ./cfssl/ca.json 2>/dev/null | cfssljson --bare "${tmpdir}/ca"
#gen an intermediate
cfssl gencert --initca ./cfssl/intermediate.json 2>/dev/null | cfssljson --bare "${tmpdir}/inter"
cfssl sign -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config ./cfssl/profiles.json -profile intermediate_ca "${tmpdir}/inter.csr" 2>/dev/null | cfssljson --bare "${tmpdir}/inter"
#gen server cert for crowdsec with the intermediate
cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=server ./cfssl/server.json 2>/dev/null | cfssljson --bare "${tmpdir}/server"
#gen client cert for the agent
cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/agent.json 2>/dev/null | cfssljson --bare "${tmpdir}/agent"
#gen client cert for the agent with an invalid OU
cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/agent_invalid.json 2>/dev/null | cfssljson --bare "${tmpdir}/agent_bad_ou"
#gen client cert for the agent directly signed by the CA, it should be refused by crowdsec as uses the intermediate
cfssl gencert -ca "${tmpdir}/ca.pem" -ca-key "${tmpdir}/ca-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/agent.json 2>/dev/null | cfssljson --bare "${tmpdir}/agent_invalid"
cfssl gencert -ca "${tmpdir}/inter.pem" -ca-key "${tmpdir}/inter-key.pem" -config ./cfssl/profiles.json -profile=client ./cfssl/agent.json 2>/dev/null | cfssljson --bare "${tmpdir}/agent_revoked"
serial="$(openssl x509 -noout -serial -in ${tmpdir}/agent_revoked.pem | cut -d '=' -f2)"
echo "ibase=16; $serial" | bc > "${tmpdir}/serials.txt"
cfssl gencrl "${tmpdir}/serials.txt" "${tmpdir}/ca.pem" "${tmpdir}/ca-key.pem" | base64 -d | openssl crl -inform DER -out "${tmpdir}/crl.pem"
yq '
.api.server.tls.cert_file=strenv(tmpdir) + "/server.pem" |
.api.server.tls.key_file=strenv(tmpdir) + "/server-key.pem" |
.api.server.tls.ca_cert_path=strenv(tmpdir) + "/inter.pem" |
.api.server.tls.crl_path=strenv(tmpdir) + "/crl.pem" |
.api.server.tls.agents_allowed_ou=["agent-ou"]
' -i "${CONFIG_YAML}"
}
teardown_file() {
load "../lib/teardown_file.sh"
}
setup() {
load "../lib/setup.sh"
cscli machines delete githubciXXXXXXXXXXXXXXXXXXXXXXXX
}
teardown() {
./instance-crowdsec stop
}
#----------
@test "$FILE invalid OU for agent" {
CONFIG_DIR=$(dirname ${CONFIG_YAML})
yq '
.ca_cert_path=strenv(tmpdir) + "/inter.pem" |
.key_path=strenv(tmpdir) + "/agent_bad_ou-key.pem" |
.cert_path=strenv(tmpdir) + "/agent_bad_ou.pem" |
.url="https://127.0.0.1:8080"
' -i "${CONFIG_DIR}/local_api_credentials.yaml"
yq 'del(.login)' -i "${CONFIG_DIR}/local_api_credentials.yaml"
yq 'del(.password)' -i "${CONFIG_DIR}/local_api_credentials.yaml"
./instance-crowdsec start
#let the agent start
sleep 2
run -0 cscli machines list -o json
assert_output '[]'
}
@test "$FILE we have exactly one machine registered with TLS" {
CONFIG_DIR=$(dirname ${CONFIG_YAML})
yq '
.ca_cert_path=strenv(tmpdir) + "/inter.pem" |
.key_path=strenv(tmpdir) + "/agent-key.pem" |
.cert_path=strenv(tmpdir) + "/agent.pem" |
.url="https://127.0.0.1:8080"
' -i "${CONFIG_DIR}/local_api_credentials.yaml"
yq 'del(.login)' -i "${CONFIG_DIR}/local_api_credentials.yaml"
yq 'del(.password)' -i "${CONFIG_DIR}/local_api_credentials.yaml"
./instance-crowdsec start
#let the agent start
sleep 2
run -0 cscli machines list -o json
run -0 jq -c '[. | length, .[0].machineId[0:32], .[0].isValidated, .[0].ipAddress, .[0].auth_type]' <(output)
assert_output '[1,"localhost@127.0.0.1",true,"127.0.0.1","tls"]'
cscli machines delete localhost@127.0.0.1
./instance-crowdsec stop
}
@test "$FILE invalid cert for agent" {
CONFIG_DIR=$(dirname ${CONFIG_YAML})
yq '
.ca_cert_path=strenv(tmpdir) + "/inter.pem" |
.key_path=strenv(tmpdir) + "/agent_invalid-key.pem" |
.cert_path=strenv(tmpdir) + "/agent_invalid.pem" |
.url="https://127.0.0.1:8080"
' -i "${CONFIG_DIR}/local_api_credentials.yaml"
yq 'del(.login)' -i "${CONFIG_DIR}/local_api_credentials.yaml"
yq 'del(.password)' -i "${CONFIG_DIR}/local_api_credentials.yaml"
./instance-crowdsec start
#let the agent start
sleep 2
run -0 cscli machines list -o json
assert_output '[]'
}
@test "$FILE revoked cert for agent" {
CONFIG_DIR=$(dirname ${CONFIG_YAML})
yq '
.ca_cert_path=strenv(tmpdir) + "/inter.pem" |
.key_path=strenv(tmpdir) + "/agent_revoked-key.pem" |
.cert_path=strenv(tmpdir) + "/agent_revoked.pem" |
.url="https://127.0.0.1:8080"
' -i "${CONFIG_DIR}/local_api_credentials.yaml"
yq 'del(.login)' -i "${CONFIG_DIR}/local_api_credentials.yaml"
yq 'del(.password)' -i "${CONFIG_DIR}/local_api_credentials.yaml"
./instance-crowdsec start
#let the agent start
sleep 2
run -0 cscli machines list -o json
assert_output '[]'
}

16
tests/cfssl/agent.json Normal file
View file

@ -0,0 +1,16 @@
{
"CN": "localhost",
"key": {
"algo": "rsa",
"size": 2048
},
"names": [
{
"C": "FR",
"L": "Paris",
"O": "Crowdsec",
"OU": "agent-ou",
"ST": "France"
}
]
}

View file

@ -0,0 +1,16 @@
{
"CN": "localhost",
"key": {
"algo": "rsa",
"size": 2048
},
"names": [
{
"C": "FR",
"L": "Paris",
"O": "Crowdsec",
"OU": "this-is-not-the-ou-youre-looking-for",
"ST": "France"
}
]
}

16
tests/cfssl/bouncer.json Normal file
View file

@ -0,0 +1,16 @@
{
"CN": "localhost",
"key": {
"algo": "rsa",
"size": 2048
},
"names": [
{
"C": "FR",
"L": "Paris",
"O": "Crowdsec",
"OU": "bouncer-ou",
"ST": "France"
}
]
}

View file

@ -0,0 +1,16 @@
{
"CN": "localhost",
"key": {
"algo": "rsa",
"size": 2048
},
"names": [
{
"C": "FR",
"L": "Paris",
"O": "Crowdsec",
"OU": "this-is-not-the-ou-youre-looking-for",
"ST": "France"
}
]
}

16
tests/cfssl/ca.json Normal file
View file

@ -0,0 +1,16 @@
{
"CN": "CrowdSec Test CA",
"key": {
"algo": "rsa",
"size": 2048
},
"names": [
{
"C": "FR",
"L": "Paris",
"O": "Crowdsec",
"OU": "Crowdsec",
"ST": "France"
}
]
}

View file

@ -0,0 +1,19 @@
{
"CN": "CrowdSec Test CA Intermediate",
"key": {
"algo": "rsa",
"size": 2048
},
"names": [
{
"C": "FR",
"L": "Paris",
"O": "Crowdsec",
"OU": "Crowdsec Intermediate",
"ST": "France"
}
],
"ca": {
"expiry": "42720h"
}
}

44
tests/cfssl/profiles.json Normal file
View file

@ -0,0 +1,44 @@
{
"signing": {
"default": {
"expiry": "8760h"
},
"profiles": {
"intermediate_ca": {
"usages": [
"signing",
"digital signature",
"key encipherment",
"cert sign",
"crl sign",
"server auth",
"client auth"
],
"expiry": "8760h",
"ca_constraint": {
"is_ca": true,
"max_path_len": 0,
"max_path_len_zero": true
}
},
"server": {
"usages": [
"signing",
"digital signing",
"key encipherment",
"server auth"
],
"expiry": "8760h"
},
"client": {
"usages": [
"signing",
"digital signature",
"key encipherment",
"client auth"
],
"expiry": "8760h"
}
}
}
}

20
tests/cfssl/server.json Normal file
View file

@ -0,0 +1,20 @@
{
"CN": "localhost",
"key": {
"algo": "rsa",
"size": 2048
},
"names": [
{
"C": "FR",
"L": "Paris",
"O": "Crowdsec",
"OU": "Crowdsec Server",
"ST": "France"
}
],
"hosts": [
"127.0.0.1",
"localhost"
]
}

View file

@ -68,6 +68,22 @@ check_daemonizer() {
esac esac
} }
check_cfssl() {
# shellcheck disable=SC2016
howto_install='You can install it with "go get -u github.com/cloudflare/cfssl/cmd/cfssl" and add ~/go/bin to $PATH.'
if ! command -v cfssl >/dev/null; then
die "Missing required program 'cfssl'. $howto_install"
fi
}
check_cfssljson() {
# shellcheck disable=SC2016
howto_install='You can install it with "go get -u github.com/cloudflare/cfssl/cmd/cfssljson" and add ~/go/bin to $PATH.'
if ! command -v cfssljson >/dev/null; then
die "Missing required program 'cfssljson'. $howto_install"
fi
}
check_gocovmerge() { check_gocovmerge() {
if ! command -v gocovmerge >/dev/null; then if ! command -v gocovmerge >/dev/null; then
die "missing required program 'gocovmerge'. You can install it with \"go install github.com/wadey/gocovmerge@latest\"" die "missing required program 'gocovmerge'. You can install it with \"go install github.com/wadey/gocovmerge@latest\""
@ -76,6 +92,8 @@ check_gocovmerge() {
check_bats_core check_bats_core
check_daemonizer check_daemonizer
check_cfssl
check_cfssljson
check_jq check_jq
check_nc check_nc
check_python3 check_python3