apiclient/apiserver: lint (#2739)

This commit is contained in:
mmetc 2024-01-15 11:44:38 +01:00 committed by GitHub
parent 03bb194d2c
commit 75d8ad9798
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 155 additions and 92 deletions

View file

@ -49,15 +49,15 @@ type AlertsDeleteOpts struct {
}
func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest) (*models.AddAlertsResponse, *Response, error) {
var addedIds models.AddAlertsResponse
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &alerts)
req, err := s.client.NewRequest(http.MethodPost, u, &alerts)
if err != nil {
return nil, nil, err
}
var addedIds models.AddAlertsResponse
resp, err := s.client.Do(ctx, req, &addedIds)
if err != nil {
return nil, resp, err
@ -68,22 +68,16 @@ func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest)
// to demo query arguments
func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.GetAlertsResponse, *Response, error) {
var (
alerts models.GetAlertsResponse
URI string
)
u := fmt.Sprintf("%s/alerts", s.client.URLPrefix)
params, err := qs.Values(opts)
params, err := qs.Values(opts)
if err != nil {
return nil, nil, fmt.Errorf("building query: %w", err)
}
URI := u
if len(params) > 0 {
URI = fmt.Sprintf("%s?%s", u, params.Encode())
} else {
URI = u
URI = fmt.Sprintf("%s?%s", URI, params.Encode())
}
req, err := s.client.NewRequest(http.MethodGet, URI, nil)
@ -91,6 +85,8 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.
return nil, nil, fmt.Errorf("building request: %w", err)
}
alerts := models.GetAlertsResponse{}
resp, err := s.client.Do(ctx, req, &alerts)
if err != nil {
return nil, resp, fmt.Errorf("performing request: %w", err)
@ -101,8 +97,6 @@ func (s *AlertsService) List(ctx context.Context, opts AlertsListOpts) (*models.
// to demo query arguments
func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*models.DeleteAlertsResponse, *Response, error) {
var alerts models.DeleteAlertsResponse
params, err := qs.Values(opts)
if err != nil {
return nil, nil, err
@ -115,6 +109,8 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod
return nil, nil, err
}
alerts := models.DeleteAlertsResponse{}
resp, err := s.client.Do(ctx, req, &alerts)
if err != nil {
return nil, resp, err
@ -124,8 +120,6 @@ func (s *AlertsService) Delete(ctx context.Context, opts AlertsDeleteOpts) (*mod
}
func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.DeleteAlertsResponse, *Response, error) {
var alerts models.DeleteAlertsResponse
u := fmt.Sprintf("%s/alerts/%s", s.client.URLPrefix, alertID)
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
@ -133,6 +127,8 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.
return nil, nil, err
}
alerts := models.DeleteAlertsResponse{}
resp, err := s.client.Do(ctx, req, &alerts)
if err != nil {
return nil, resp, err
@ -142,8 +138,6 @@ func (s *AlertsService) DeleteOne(ctx context.Context, alertID string) (*models.
}
func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert, *Response, error) {
var alert models.Alert
u := fmt.Sprintf("%s/alerts/%d", s.client.URLPrefix, alertID)
req, err := s.client.NewRequest(http.MethodGet, u, nil)
@ -151,6 +145,8 @@ func (s *AlertsService) GetByID(ctx context.Context, alertID int) (*models.Alert
return nil, nil, err
}
alert := models.Alert{}
resp, err := s.client.Do(ctx, req, &alert)
if err != nil {
return nil, nil, err

View file

@ -125,7 +125,7 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
select {
case <-req.Context().Done():
return resp, req.Context().Err()
return nil, req.Context().Err()
case <-time.After(time.Duration(backoff) * time.Second):
}
}
@ -135,8 +135,8 @@ func (r retryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
}
clonedReq := cloneRequest(req)
resp, err = r.next.RoundTrip(clonedReq)
resp, err = r.next.RoundTrip(clonedReq)
if err != nil {
if left := maxAttempts - i - 1; left > 0 {
log.Errorf("error while performing request: %s; %d retries left", err, left)
@ -171,10 +171,11 @@ type JWTTransport struct {
func (t *JWTTransport) refreshJwtToken() error {
var err error
if t.UpdateScenario != nil {
t.Scenarios, err = t.UpdateScenario()
if err != nil {
return fmt.Errorf("can't update scenario list: %s", err)
return fmt.Errorf("can't update scenario list: %w", err)
}
log.Debugf("scenarios list updated for '%s'", *t.MachineID)
@ -186,8 +187,6 @@ func (t *JWTTransport) refreshJwtToken() error {
Scenarios: t.Scenarios,
}
var response models.WatcherAuthResponse
/*
we don't use the main client, so let's build the body
*/
@ -250,6 +249,8 @@ func (t *JWTTransport) refreshJwtToken() error {
}
}
var response models.WatcherAuthResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return fmt.Errorf("unable to decode response: %w", err)
}
@ -324,14 +325,13 @@ func (t *JWTTransport) ResetToken() {
t.refreshTokenMutex.Unlock()
}
// transport() returns a round tripper that retries once when the status is unauthorized, and 5 times when the infrastructure is overloaded.
func (t *JWTTransport) transport() http.RoundTripper {
var transport http.RoundTripper
if t.Transport != nil {
transport = t.Transport
} else {
transport := t.Transport
if transport == nil {
transport = http.DefaultTransport
}
// a round tripper that retries once when the status is unauthorized and 5 times when infrastructure is overloaded
return &retryRoundTripper{
next: &retryRoundTripper{
next: transport,

View file

@ -3,11 +3,11 @@ package apiclient
import (
"bufio"
"context"
"errors"
"fmt"
"net/http"
qs "github.com/google/go-querystring/query"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/crowdsecurity/go-cs-lib/ptr"
@ -61,8 +61,6 @@ type DecisionsDeleteOpts struct {
// to demo query arguments
func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*models.GetDecisionsResponse, *Response, error) {
var decisions models.GetDecisionsResponse
params, err := qs.Values(opts)
if err != nil {
return nil, nil, err
@ -75,6 +73,8 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m
return nil, nil, err
}
var decisions models.GetDecisionsResponse
resp, err := s.client.Do(ctx, req, &decisions)
if err != nil {
return nil, resp, err
@ -84,13 +84,13 @@ func (s *DecisionsService) List(ctx context.Context, opts DecisionsListOpts) (*m
}
func (s *DecisionsService) FetchV2Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) {
var decisions models.DecisionsStreamResponse
req, err := s.client.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, nil, err
}
var decisions models.DecisionsStreamResponse
resp, err := s.client.Do(ctx, req, &decisions)
if err != nil {
return nil, resp, err
@ -100,7 +100,7 @@ func (s *DecisionsService) FetchV2Decisions(ctx context.Context, url string) (*m
}
func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi.GetDecisionsStreamResponseNewItem) []*models.Decision {
var decisions []*models.Decision
decisions := make([]*models.Decision, 0)
for _, decisionsGroup := range decisionsGroups {
partialDecisions := make([]*models.Decision, len(decisionsGroup.Decisions))
@ -122,11 +122,6 @@ func (s *DecisionsService) GetDecisionsFromGroups(decisionsGroups []*modelscapi.
}
func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*models.DecisionsStreamResponse, *Response, error) {
var (
decisions modelscapi.GetDecisionsStreamResponse
v2Decisions models.DecisionsStreamResponse
)
scenarioDeleted := "deleted"
durationDeleted := "1h"
@ -135,11 +130,14 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
return nil, nil, err
}
decisions := modelscapi.GetDecisionsStreamResponse{}
resp, err := s.client.Do(ctx, req, &decisions)
if err != nil {
return nil, resp, err
}
v2Decisions := models.DecisionsStreamResponse{}
v2Decisions.New = s.GetDecisionsFromGroups(decisions.New)
for _, decisionsGroup := range decisions.Deleted {
@ -183,6 +181,7 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
req = req.WithContext(ctx)
log.Debugf("[URL] %s %s", req.Method, req.URL)
// we don't use client_http Do method because we need the reader and is not provided.
// We would be forced to use Pipe and goroutine, etc
resp, err := client.Do(req)
@ -247,11 +246,11 @@ func (s *DecisionsService) GetStream(ctx context.Context, opts DecisionsStreamOp
return nil, nil, err
}
if s.client.URLPrefix == "v3" {
return s.FetchV3Decisions(ctx, u)
} else {
if s.client.URLPrefix != "v3" {
return s.FetchV2Decisions(ctx, u)
}
return s.FetchV3Decisions(ctx, u)
}
func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStreamOpts) (*modelscapi.GetDecisionsStreamResponse, *Response, error) {
@ -260,13 +259,13 @@ func (s *DecisionsService) GetStreamV3(ctx context.Context, opts DecisionsStream
return nil, nil, err
}
var decisions modelscapi.GetDecisionsStreamResponse
req, err := s.client.NewRequest(http.MethodGet, u, nil)
if err != nil {
return nil, nil, err
}
decisions := modelscapi.GetDecisionsStreamResponse{}
resp, err := s.client.Do(ctx, req, &decisions)
if err != nil {
return nil, resp, err
@ -292,8 +291,6 @@ func (s *DecisionsService) StopStream(ctx context.Context) (*Response, error) {
}
func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts) (*models.DeleteDecisionResponse, *Response, error) {
var deleteDecisionResponse models.DeleteDecisionResponse
params, err := qs.Values(opts)
if err != nil {
return nil, nil, err
@ -306,6 +303,8 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts)
return nil, nil, err
}
deleteDecisionResponse := models.DeleteDecisionResponse{}
resp, err := s.client.Do(ctx, req, &deleteDecisionResponse)
if err != nil {
return nil, resp, err
@ -315,8 +314,6 @@ func (s *DecisionsService) Delete(ctx context.Context, opts DecisionsDeleteOpts)
}
func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*models.DeleteDecisionResponse, *Response, error) {
var deleteDecisionResponse models.DeleteDecisionResponse
u := fmt.Sprintf("%s/decisions/%s", s.client.URLPrefix, decisionID)
req, err := s.client.NewRequest(http.MethodDelete, u, nil)
@ -324,6 +321,8 @@ func (s *DecisionsService) DeleteOne(ctx context.Context, decisionID string) (*m
return nil, nil, err
}
deleteDecisionResponse := models.DeleteDecisionResponse{}
resp, err := s.client.Do(ctx, req, &deleteDecisionResponse)
if err != nil {
return nil, resp, err

View file

@ -14,8 +14,6 @@ type DecisionDeleteService service
// DecisionDeleteService purposely reuses AddSignalsRequestItemDecisions model
func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *models.DecisionsDeleteRequest) (interface{}, *Response, error) {
var response interface{}
u := fmt.Sprintf("%s/decisions/delete", d.client.URLPrefix)
req, err := d.client.NewRequest(http.MethodPost, u, &deletedDecisions)
@ -23,6 +21,8 @@ func (d *DecisionDeleteService) Add(ctx context.Context, deletedDecisions *model
return nil, nil, fmt.Errorf("while building request: %w", err)
}
var response interface{}
resp, err := d.client.Do(ctx, req, &response)
if err != nil {
return nil, resp, fmt.Errorf("while performing request: %w", err)

View file

@ -11,8 +11,6 @@ import (
type MetricsService service
func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (interface{}, *Response, error) {
var response interface{}
u := fmt.Sprintf("%s/metrics/", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &metrics)
@ -20,6 +18,8 @@ func (s *MetricsService) Add(ctx context.Context, metrics *models.Metrics) (inte
return nil, nil, err
}
var response interface{}
resp, err := s.client.Do(ctx, req, &response)
if err != nil {
return nil, resp, err

View file

@ -13,8 +13,6 @@ import (
type SignalService service
func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsRequest) (interface{}, *Response, error) {
var response interface{}
u := fmt.Sprintf("%s/signals", s.client.URLPrefix)
req, err := s.client.NewRequest(http.MethodPost, u, &signals)
@ -22,6 +20,8 @@ func (s *SignalService) Add(ctx context.Context, signals *models.AddSignalsReque
return nil, nil, fmt.Errorf("while building request: %w", err)
}
var response interface{}
resp, err := s.client.Do(ctx, req, &response)
if err != nil {
return nil, resp, fmt.Errorf("while performing request: %w", err)

View file

@ -2,21 +2,21 @@ package apiserver
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"net/http"
"net/url"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/go-openapi/strfmt"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"gopkg.in/tomb.v2"
"slices"
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/go-cs-lib/trace"
@ -652,12 +652,13 @@ func (a *apic) PullTop(forcePull bool) error {
addCounters, deleteCounters := makeAddAndDeleteCounters()
// process deleted decisions
if nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters); err != nil {
nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters)
if err != nil {
return err
} else {
log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted)
}
log.Printf("capi/community-blocklist : %d explicit deletions", nbDeleted)
if len(data.New) == 0 {
log.Infof("capi/community-blocklist : received 0 new entries (expected if you just installed crowdsec)")
return nil

View file

@ -2,6 +2,7 @@ package apiserver
import (
"context"
"errors"
"fmt"
"io"
"net"
@ -13,7 +14,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-co-op/gocron"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
"gopkg.in/tomb.v2"
@ -382,7 +382,9 @@ func (s *APIServer) listenAndServeURL(apiReady chan bool) {
if s.TLS.KeyFilePath == "" {
serverError <- errors.New("missing TLS key file")
return
} else if s.TLS.CertFilePath == "" {
}
if s.TLS.CertFilePath == "" {
serverError <- errors.New("missing TLS cert file")
return
}

View file

@ -59,8 +59,10 @@ func FormatOneAlert(alert *ent.Alert) *models.Alert {
}
for _, eventItem := range alert.Edges.Events {
var Metas models.Meta
timestamp := eventItem.Time.String()
var Metas models.Meta
if err := json.Unmarshal([]byte(eventItem.Serialized), &Metas); err != nil {
log.Errorf("unable to unmarshall events meta '%s' : %s", eventItem.Serialized, err)
}
@ -162,6 +164,7 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
if alert.Source.Scope != nil {
*alert.Source.Scope = normalizeScope(*alert.Source.Scope)
}
for _, decision := range alert.Decisions {
if decision.Scope != nil {
*decision.Scope = normalizeScope(*decision.Scope)
@ -183,30 +186,38 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
_, matched, err := profile.EvaluateProfile(alert)
if err != nil {
profile.Logger.Warningf("error while evaluating profile %s : %v", profile.Cfg.Name, err)
continue
}
if !matched {
continue
}
c.sendAlertToPluginChannel(alert, uint(pIdx))
if profile.Cfg.OnSuccess == "break" {
break
}
}
decision := alert.Decisions[0]
if decision.Origin != nil && *decision.Origin == types.CscliImportOrigin {
stopFlush = true
}
continue
}
for pIdx, profile := range c.Profiles {
profileDecisions, matched, err := profile.EvaluateProfile(alert)
forceBreak := false
if err != nil {
switch profile.Cfg.OnError {
case "apply":
profile.Logger.Warningf("applying profile %s despite error: %s", profile.Cfg.Name, err)
matched = true
case "continue":
profile.Logger.Warningf("skipping %s profile due to error: %s", profile.Cfg.Name, err)
@ -219,18 +230,23 @@ func (c *Controller) CreateAlert(gctx *gin.Context) {
return
}
}
if !matched {
continue
}
for _, decision := range profileDecisions {
decision.UUID = uuid.NewString()
}
// generate uuid here for alert
if len(alert.Decisions) == 0 { // non manual decision
alert.Decisions = append(alert.Decisions, profileDecisions...)
}
profileAlert := *alert
c.sendAlertToPluginChannel(&profileAlert, uint(pIdx))
if profile.Cfg.OnSuccess == "break" || forceBreak {
break
}
@ -275,6 +291,7 @@ func (c *Controller) FindAlerts(gctx *gin.Context) {
gctx.String(http.StatusOK, "")
return
}
gctx.JSON(http.StatusOK, data)
}
@ -282,21 +299,25 @@ func (c *Controller) FindAlerts(gctx *gin.Context) {
func (c *Controller) FindAlertByID(gctx *gin.Context) {
alertIDStr := gctx.Param("alert_id")
alertID, err := strconv.Atoi(alertIDStr)
if err != nil {
gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"})
return
}
result, err := c.DBClient.GetAlertByID(alertID)
if err != nil {
c.HandleDBErrors(gctx, err)
return
}
data := FormatOneAlert(result)
if gctx.Request.Method == http.MethodHead {
gctx.String(http.StatusOK, "")
return
}
gctx.JSON(http.StatusOK, data)
}
@ -316,15 +337,14 @@ func (c *Controller) DeleteAlertByID(gctx *gin.Context) {
gctx.JSON(http.StatusBadRequest, gin.H{"message": "alert_id must be valid integer"})
return
}
err = c.DBClient.DeleteAlertByID(decisionID)
if err != nil {
c.HandleDBErrors(gctx, err)
return
}
deleteAlertResp := models.DeleteAlertsResponse{
NbDeleted: "1",
}
deleteAlertResp := models.DeleteAlertsResponse{NbDeleted: "1"}
gctx.JSON(http.StatusOK, deleteAlertResp)
}
@ -336,15 +356,17 @@ func (c *Controller) DeleteAlerts(gctx *gin.Context) {
gctx.JSON(http.StatusForbidden, gin.H{"message": fmt.Sprintf("access forbidden from this IP (%s)", incomingIP)})
return
}
var err error
nbDeleted, err := c.DBClient.DeleteAlertWithFilter(gctx.Request.URL.Query())
if err != nil {
c.HandleDBErrors(gctx, err)
return
}
deleteAlertsResp := models.DeleteAlertsResponse{
NbDeleted: strconv.Itoa(nbDeleted),
}
gctx.JSON(http.StatusOK, deleteAlertsResp)
}
@ -355,5 +377,6 @@ func networksContainIP(networks []net.IPNet, ip string) bool {
return true
}
}
return false
}

View file

@ -61,8 +61,10 @@ func New(cfg *ControllerV1Config) (*Controller, error) {
TrustedIPs: cfg.TrustedIPs,
}
v1.Middlewares, err = middlewares.NewMiddlewares(cfg.DbClient)
if err != nil {
return v1, err
}
return v1, nil
}

View file

@ -33,6 +33,7 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision {
}
results = append(results, &decision)
}
return results
}
@ -44,12 +45,14 @@ func (c *Controller) GetDecision(gctx *gin.Context) {
bouncerInfo, err := getBouncerFromContext(gctx)
if err != nil {
gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"})
return
}
data, err = c.DBClient.QueryDecisionWithFilter(gctx.Request.URL.Query())
if err != nil {
c.HandleDBErrors(gctx, err)
return
}
@ -64,6 +67,7 @@ func (c *Controller) GetDecision(gctx *gin.Context) {
if gctx.Request.Method == http.MethodHead {
gctx.String(http.StatusOK, "")
return
}
@ -77,19 +81,21 @@ func (c *Controller) GetDecision(gctx *gin.Context) {
}
func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
var err error
decisionIDStr := gctx.Param("decision_id")
decisionID, err := strconv.Atoi(decisionIDStr)
if err != nil {
gctx.JSON(http.StatusBadRequest, gin.H{"message": "decision_id must be valid integer"})
return
}
nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionByID(decisionID)
if err != nil {
c.HandleDBErrors(gctx, err)
return
}
// transform deleted decisions to be sendable to capi
deletedDecisions := FormatDecisions(deletedFromDB)
@ -105,12 +111,13 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
}
func (c *Controller) DeleteDecisions(gctx *gin.Context) {
var err error
nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionsWithFilter(gctx.Request.URL.Query())
if err != nil {
c.HandleDBErrors(gctx, err)
return
}
// transform deleted decisions to be sendable to capi
deletedDecisions := FormatDecisions(deletedFromDB)
@ -121,6 +128,7 @@ func (c *Controller) DeleteDecisions(gctx *gin.Context) {
deleteDecisionResp := models.DeleteDecisionResponse{
NbDeleted: nbDeleted,
}
gctx.JSON(http.StatusOK, deleteDecisionResp)
}
@ -147,6 +155,7 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun
results := FormatDecisions(data)
for _, decision := range results {
decisionJSON, _ := json.Marshal(decision)
if needComma {
//respBuffer.Write([]byte(","))
gctx.Writer.Write([]byte(","))
@ -158,6 +167,7 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun
_, err := gctx.Writer.Write(decisionJSON)
if err != nil {
gctx.Writer.Flush()
return err
}
//respBuffer.Reset()
@ -166,9 +176,11 @@ func writeStartupDecisions(gctx *gin.Context, filters map[string][]string, dbFun
log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId)
if len(data) < limit {
gctx.Writer.Flush()
break
}
}
return nil
}
@ -195,6 +207,7 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul
results := FormatDecisions(data)
for _, decision := range results {
decisionJSON, _ := json.Marshal(decision)
if needComma {
//respBuffer.Write([]byte(","))
gctx.Writer.Write([]byte(","))
@ -206,6 +219,7 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul
_, err := gctx.Writer.Write(decisionJSON)
if err != nil {
gctx.Writer.Flush()
return err
}
//respBuffer.Reset()
@ -214,9 +228,11 @@ func writeDeltaDecisions(gctx *gin.Context, filters map[string][]string, lastPul
log.Debugf("startup: %d decisions returned (limit: %d, lastid: %d)", len(data), limit, lastId)
if len(data) < limit {
gctx.Writer.Flush()
break
}
}
return nil
}
@ -231,13 +247,12 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B
// if the blocker just started, return all decisions
if val, ok := gctx.Request.URL.Query()["startup"]; ok && val[0] == "true" {
// Active decisions
err := writeStartupDecisions(gctx, filters, c.DBClient.QueryAllDecisionsWithFilters)
if err != nil {
log.Errorf("failed sending new decisions for startup: %v", err)
gctx.Writer.Write([]byte(`], "deleted": []}`))
gctx.Writer.Flush()
return err
}
@ -248,6 +263,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B
log.Errorf("failed sending expired decisions for startup: %v", err)
gctx.Writer.Write([]byte(`]}`))
gctx.Writer.Flush()
return err
}
@ -259,6 +275,7 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B
log.Errorf("failed sending new decisions for delta: %v", err)
gctx.Writer.Write([]byte(`], "deleted": []}`))
gctx.Writer.Flush()
return err
}
@ -270,18 +287,21 @@ func (c *Controller) StreamDecisionChunked(gctx *gin.Context, bouncerInfo *ent.B
log.Errorf("failed sending expired decisions for delta: %v", err)
gctx.Writer.Write([]byte(`]}`))
gctx.Writer.Flush()
return err
}
gctx.Writer.Write([]byte(`]}`))
gctx.Writer.Flush()
}
return nil
}
func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *ent.Bouncer, streamStartTime time.Time, filters map[string][]string) error {
var data []*ent.Decision
var err error
ret := make(map[string][]*models.Decision, 0)
ret["new"] = []*models.Decision{}
ret["deleted"] = []*models.Decision{}
@ -292,6 +312,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
if err != nil {
log.Errorf("failed querying decisions: %v", err)
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return err
}
//data = KeepLongestDecision(data)
@ -302,11 +323,14 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
if err != nil {
log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return err
}
ret["deleted"] = FormatDecisions(data)
gctx.JSON(http.StatusOK, ret)
return nil
}
}
@ -316,6 +340,7 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
if err != nil {
log.Errorf("unable to query new decision for '%s' : %v", bouncerInfo.Name, err)
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return err
}
//data = KeepLongestDecision(data)

View file

@ -8,7 +8,6 @@ import (
)
func (c *Controller) HeartBeat(gctx *gin.Context) {
claims := jwt.ExtractClaims(gctx)
// TBD: use defined rather than hardcoded key to find back owner
machineID := claims["id"].(string)
@ -17,5 +16,6 @@ func (c *Controller) HeartBeat(gctx *gin.Context) {
c.HandleDBErrors(gctx, err)
return
}
gctx.Status(http.StatusOK)
}

View file

@ -11,19 +11,19 @@ import (
)
func (c *Controller) CreateMachine(gctx *gin.Context) {
var err error
var input models.WatcherRegistrationRequest
if err = gctx.ShouldBindJSON(&input); err != nil {
if err := gctx.ShouldBindJSON(&input); err != nil {
gctx.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
return
}
if err = input.Validate(strfmt.Default); err != nil {
if err := input.Validate(strfmt.Default); err != nil {
c.HandleDBErrors(gctx, err)
return
}
_, err = c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType)
if err != nil {
if _, err := c.DBClient.CreateMachine(input.MachineID, input.Password, gctx.ClientIP(), false, false, types.PasswordAuthType); err != nil {
c.HandleDBErrors(gctx, err)
return
}

View file

@ -93,6 +93,7 @@ func PrometheusMachinesMiddleware() gin.HandlerFunc {
"method": c.Request.Method}).Inc()
}
}
c.Next()
}
}
@ -106,6 +107,7 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc {
"route": c.Request.URL.Path,
"method": c.Request.Method}).Inc()
}
c.Next()
}
}
@ -117,6 +119,7 @@ func PrometheusMiddleware() gin.HandlerFunc {
"route": c.Request.URL.Path,
"method": c.Request.Method}).Inc()
c.Next()
elapsed := time.Since(startTime)
LapiResponseTime.With(prometheus.Labels{"method": c.Request.Method, "endpoint": c.Request.URL.Path}).Observe(elapsed.Seconds())
}

View file

@ -9,9 +9,7 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/database/ent"
)
var (
bouncerContextKey = "bouncer_info"
)
const bouncerContextKey = "bouncer_info"
func getBouncerFromContext(ctx *gin.Context) (*ent.Bouncer, error) {
bouncerInterface, exist := ctx.Get(bouncerContextKey)

View file

@ -34,7 +34,9 @@ func GenerateAPIKey(n int) (string, error) {
if _, err := rand.Read(bytes); err != nil {
return "", err
}
encoded := base64.StdEncoding.EncodeToString(bytes)
// the '=' can cause issues on some bouncers
return strings.TrimRight(encoded, "="), nil
}
@ -67,6 +69,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {
logger.Errorf("invalid client certificate: %s", err)
return nil
}
if err != nil {
logger.Error(err)
return nil
@ -88,7 +91,9 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {
logger.Errorf("error generating mock api key: %s", err)
return nil
}
logger.Infof("Creating bouncer %s", bouncerName)
bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType)
if err != nil {
logger.Errorf("while creating bouncer db entry: %s", err)
@ -103,6 +108,7 @@ func (a *APIKey) authTLS(c *gin.Context, logger *log.Entry) *ent.Bouncer {
logger.Errorf("bouncer isn't allowed to auth by TLS")
return nil
}
return bouncer
}
@ -112,6 +118,7 @@ func (a *APIKey) authPlain(c *gin.Context, logger *log.Entry) *ent.Bouncer {
logger.Errorf("API key not found")
return nil
}
hashStr := HashSHA512(val[0])
bouncer, err := a.DbClient.SelectBouncer(hashStr)
@ -162,16 +169,19 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
}
if bouncer.IPAddress != c.ClientIP() && bouncer.IPAddress != "" {
log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, c.ClientIP(), bouncer.IPAddress)
if err := a.DbClient.UpdateBouncerIP(c.ClientIP(), bouncer.ID); err != nil {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return
}
}
@ -187,6 +197,7 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
logger.Errorf("failed to update bouncer version and type: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"})
c.Abort()
return
}
}

View file

@ -36,14 +36,15 @@ func PayloadFunc(data interface{}) jwt.MapClaims {
identityKey: &value.MachineID,
}
}
return jwt.MapClaims{}
}
func IdentityHandler(c *gin.Context) interface{} {
claims := jwt.ExtractClaims(c)
machineId := claims[identityKey].(string)
machineID := claims[identityKey].(string)
return &models.WatcherAuthRequest{
MachineID: &machineId,
MachineID: &machineID,
}
}
@ -67,6 +68,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
log.Error(err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
return nil, fmt.Errorf("while trying to validate client cert: %w", err)
}
@ -77,6 +79,7 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
}
ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
Where(machine.MachineId(ret.machineID)).
First(j.DbClient.CTX)