apiclient/apiserver: lint/2 (#2741)

This commit is contained in:
mmetc 2024-01-15 12:38:31 +01:00 committed by GitHub
parent 75d8ad9798
commit 48f011dc1c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 121 additions and 20 deletions

View file

@ -56,7 +56,7 @@ func (s *AlertsService) Add(ctx context.Context, alerts models.AddAlertsRequest)
return nil, nil, err
}
var addedIds models.AddAlertsResponse
addedIds := models.AddAlertsResponse{}
resp, err := s.client.Do(ctx, req, &addedIds)
if err != nil {

View file

@ -50,6 +50,7 @@ func (h *HeartBeatService) StartHeartBeat(ctx context.Context, t *tomb.Tomb) {
log.Errorf("heartbeat unexpected return code: %d", resp.Response.StatusCode)
continue
}
if !ok {
log.Errorf("heartbeat returned false")
continue

View file

@ -651,6 +651,7 @@ func (a *apic) PullTop(forcePull bool) error {
}
addCounters, deleteCounters := makeAddAndDeleteCounters()
// process deleted decisions
nbDeleted, err := a.HandleDeletedDecisionsV3(data.Deleted, deleteCounters)
if err != nil {

View file

@ -38,9 +38,10 @@ func FormatDecisions(decisions []*ent.Decision) []*models.Decision {
}
func (c *Controller) GetDecision(gctx *gin.Context) {
var err error
var results []*models.Decision
var data []*ent.Decision
var (
results []*models.Decision
data []*ent.Decision
)
bouncerInfo, err := getBouncerFromContext(gctx)
if err != nil {
@ -89,6 +90,7 @@ func (c *Controller) DeleteDecisionById(gctx *gin.Context) {
return
}
nbDeleted, deletedFromDB, err := c.DBClient.SoftDeleteDecisionByID(decisionID)
if err != nil {
c.HandleDBErrors(gctx, err)
@ -351,10 +353,13 @@ func (c *Controller) StreamDecisionNonChunked(gctx *gin.Context, bouncerInfo *en
if err != nil {
log.Errorf("unable to query expired decision for '%s' : %v", bouncerInfo.Name, err)
gctx.JSON(http.StatusInternalServerError, gin.H{"message": err.Error()})
return err
}
ret["deleted"] = FormatDecisions(data)
gctx.JSON(http.StatusOK, ret)
return nil
}
@ -362,9 +367,11 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
var err error
streamStartTime := time.Now().UTC()
bouncerInfo, err := getBouncerFromContext(gctx)
if err != nil {
gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"})
return
}
@ -372,6 +379,7 @@ func (c *Controller) StreamDecision(gctx *gin.Context) {
//For HEAD, just return as the bouncer won't get a body anyway, so no need to query the db
//We also don't update the last pull time, as it would mess with the delta sent on the next request (if done without startup=true)
gctx.String(http.StatusOK, "")
return
}

View file

@ -115,6 +115,7 @@ func PrometheusBouncersMiddleware() gin.HandlerFunc {
func PrometheusMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
startTime := time.Now()
LapiRouteHits.With(prometheus.Labels{
"route": c.Request.URL.Path,
"method": c.Request.Method}).Inc()

View file

@ -203,7 +203,6 @@ func (a *APIKey) MiddlewareFunc() gin.HandlerFunc {
}
c.Set(bouncerContextKey, bouncer)
c.Next()
}
}

View file

@ -43,6 +43,7 @@ func PayloadFunc(data interface{}) jwt.MapClaims {
func IdentityHandler(c *gin.Context) interface{} {
claims := jwt.ExtractClaims(c)
machineID := claims[identityKey].(string)
return &models.WatcherAuthRequest{
MachineID: &machineID,
}
@ -93,9 +94,12 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
"ip": c.ClientIP(),
"cn": extractedCN,
}).Errorf("error generating password: %s", err)
return nil, fmt.Errorf("error generating password")
}
password := strfmt.Password(pwd)
ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType)
if err != nil {
return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err)
@ -114,27 +118,33 @@ func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
}{
Scenarios: []string{},
}
err = c.ShouldBindJSON(&loginInput)
if err != nil {
return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err)
}
ret.scenariosInput = loginInput.Scenarios
return &ret, nil
}
func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
var loginInput models.WatcherAuthRequest
var err error
var (
loginInput models.WatcherAuthRequest
err error
)
ret := authInput{}
if err = c.ShouldBindJSON(&loginInput); err != nil {
return nil, fmt.Errorf("missing: %w", err)
}
if err = loginInput.Validate(strfmt.Default); err != nil {
return nil, err
}
ret.machineID = *loginInput.MachineID
password := *loginInput.Password
ret.scenariosInput = loginInput.Scenarios
@ -168,8 +178,10 @@ func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
}
func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
var err error
var auth *authInput
var (
err error
auth *authInput
)
if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
auth, err = j.authTLS(c)
@ -193,6 +205,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
scenarios += "," + scenario
}
}
err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID)
if err != nil {
log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err)
@ -210,6 +223,7 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
if auth.clientMachine.IpAddress != c.ClientIP() && auth.clientMachine.IpAddress != "" {
log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, c.ClientIP(), auth.clientMachine.IpAddress)
err = j.DbClient.UpdateMachineIP(c.ClientIP(), auth.clientMachine.ID)
if err != nil {
log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err)
@ -228,10 +242,10 @@ func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
log.Errorf("bad user agent from : %s", c.ClientIP())
return nil, jwt.ErrFailedAuthentication
}
return &models.WatcherAuthRequest{
MachineID: &auth.machineID,
}, nil
}
func Authorizator(data interface{}, c *gin.Context) bool {

View file

@ -18,5 +18,6 @@ func NewMiddlewares(dbClient *database.Client) (*Middlewares, error) {
}
ret.APIKey = NewAPIKey(dbClient)
return ret, nil
}

View file

@ -36,32 +36,40 @@ func (ta *TLSAuth) ocspQuery(server string, cert *x509.Certificate, issuer *x509
ta.logger.Errorf("TLSAuth: error creating OCSP request: %s", err)
return nil, err
}
httpRequest, err := http.NewRequest(http.MethodPost, server, bytes.NewBuffer(req))
if err != nil {
ta.logger.Error("TLSAuth: cannot create HTTP request for OCSP")
return nil, err
}
ocspURL, err := url.Parse(server)
if err != nil {
ta.logger.Error("TLSAuth: cannot parse OCSP URL")
return nil, err
}
httpRequest.Header.Add("Content-Type", "application/ocsp-request")
httpRequest.Header.Add("Accept", "application/ocsp-response")
httpRequest.Header.Add("host", ocspURL.Host)
httpClient := &http.Client{}
httpResponse, err := httpClient.Do(httpRequest)
if err != nil {
ta.logger.Error("TLSAuth: cannot send HTTP request to OCSP")
return nil, err
}
defer httpResponse.Body.Close()
output, err := io.ReadAll(httpResponse.Body)
if err != nil {
ta.logger.Error("TLSAuth: cannot read HTTP response from OCSP")
return nil, err
}
ocspResponse, err := ocsp.ParseResponseForCert(output, cert, issuer)
return ocspResponse, err
}
@ -72,10 +80,12 @@ func (ta *TLSAuth) isExpired(cert *x509.Certificate) bool {
ta.logger.Errorf("TLSAuth: client certificate is expired (NotAfter: %s)", cert.NotAfter.UTC())
return true
}
if cert.NotBefore.UTC().After(now) {
ta.logger.Errorf("TLSAuth: client certificate is not yet valid (NotBefore: %s)", cert.NotBefore.UTC())
return true
}
return false
}
@ -84,12 +94,14 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat
ta.logger.Infof("TLSAuth: no OCSP Server present in client certificate, skipping OCSP verification")
return false, nil
}
for _, server := range cert.OCSPServer {
ocspResponse, err := ta.ocspQuery(server, cert, issuer)
if err != nil {
ta.logger.Errorf("TLSAuth: error querying OCSP server %s: %s", server, err)
continue
}
switch ocspResponse.Status {
case ocsp.Good:
return false, nil
@ -100,7 +112,9 @@ func (ta *TLSAuth) isOCSPRevoked(cert *x509.Certificate, issuer *x509.Certificat
continue
}
}
log.Infof("Could not get any valid OCSP response, assuming the cert is revoked")
return true, nil
}
@ -109,24 +123,29 @@ func (ta *TLSAuth) isCRLRevoked(cert *x509.Certificate) (bool, error) {
ta.logger.Warn("no crl_path, skipping CRL check")
return false, nil
}
crlContent, err := os.ReadFile(ta.CrlPath)
if err != nil {
ta.logger.Warnf("could not read CRL file, skipping check: %s", err)
return false, nil
}
crl, err := x509.ParseCRL(crlContent)
if err != nil {
ta.logger.Warnf("could not parse CRL file, skipping check: %s", err)
return false, nil
}
if crl.HasExpired(time.Now().UTC()) {
ta.logger.Warn("CRL has expired, will still validate the cert against it.")
}
for _, revoked := range crl.TBSCertList.RevokedCertificates {
if revoked.SerialNumber.Cmp(cert.SerialNumber) == 0 {
return true, fmt.Errorf("client certificate is revoked by CRL")
}
}
return false, nil
}
@ -143,6 +162,7 @@ func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (
} else {
ta.logger.Tracef("TLSAuth: no cached value for cert %s", sn)
}
revoked, err := ta.isOCSPRevoked(cert, issuer)
if err != nil {
ta.revokationCache[sn] = cacheEntry{
@ -150,22 +170,27 @@ func (ta *TLSAuth) isRevoked(cert *x509.Certificate, issuer *x509.Certificate) (
err: err,
timestamp: time.Now().UTC(),
}
return true, err
}
if revoked {
ta.revokationCache[sn] = cacheEntry{
revoked: revoked,
err: err,
timestamp: time.Now().UTC(),
}
return true, nil
}
revoked, err = ta.isCRLRevoked(cert)
ta.revokationCache[sn] = cacheEntry{
revoked: revoked,
err: err,
timestamp: time.Now().UTC(),
}
return revoked, err
}
@ -173,6 +198,7 @@ func (ta *TLSAuth) isInvalid(cert *x509.Certificate, issuer *x509.Certificate) (
if ta.isExpired(cert) {
return true, nil
}
revoked, err := ta.isRevoked(cert, issuer)
if err != nil {
//Fail securely, if we can't check the revocation status, let's consider the cert invalid
@ -189,24 +215,30 @@ func (ta *TLSAuth) SetAllowedOu(allowedOus []string) error {
if ou == "" {
return fmt.Errorf("empty ou isn't allowed")
}
//drop & warn on duplicate ou
ok := true
for _, validOu := range ta.AllowedOUs {
if validOu == ou {
ta.logger.Warningf("dropping duplicate ou %s", ou)
ok = false
}
}
if ok {
ta.AllowedOUs = append(ta.AllowedOUs, ou)
}
}
return nil
}
func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
//Checks cert validity, Returns true + CN if client cert matches requested OU
var clientCert *x509.Certificate
if c.Request.TLS == nil || len(c.Request.TLS.PeerCertificates) == 0 {
//do not error if it's not TLS or there are no peer certs
return false, "", nil
@ -215,6 +247,7 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
if len(c.Request.TLS.VerifiedChains) > 0 {
validOU := false
clientCert = c.Request.TLS.VerifiedChains[0][0]
for _, ou := range clientCert.Subject.OrganizationalUnit {
for _, allowedOu := range ta.AllowedOUs {
if allowedOu == ou {
@ -223,21 +256,27 @@ func (ta *TLSAuth) ValidateCert(c *gin.Context) (bool, string, error) {
}
}
}
if !validOU {
return false, "", fmt.Errorf("client certificate OU (%v) doesn't match expected OU (%v)",
clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
}
revoked, err := ta.isInvalid(clientCert, c.Request.TLS.VerifiedChains[0][1])
if err != nil {
ta.logger.Errorf("TLSAuth: error checking if client certificate is revoked: %s", err)
return false, "", fmt.Errorf("could not check for client certification revokation status: %w", err)
}
if revoked {
return false, "", fmt.Errorf("client certificate is revoked")
}
ta.logger.Debugf("client OU %v is allowed vs required OU %v", clientCert.Subject.OrganizationalUnit, ta.AllowedOUs)
return true, clientCert.Subject.CommonName, nil
}
return false, "", fmt.Errorf("no verified cert in request")
}
@ -248,9 +287,11 @@ func NewTLSAuth(allowedOus []string, crlPath string, cacheExpiration time.Durati
CrlPath: crlPath,
logger: logger,
}
err := ta.SetAllowedOu(allowedOus)
if err != nil {
return nil, err
}
return ta, nil
}

View file

@ -205,12 +205,15 @@ func (p *Papi) PullOnce(since time.Time, sync bool) error {
reversedEvents := reverse(events) //PAPI sends events in the reverse order, which is not an issue when pulling them in real time, but here we need the correct order
eventsCount := len(events)
p.Logger.Infof("received %d events", eventsCount)
for i, event := range reversedEvents {
if err := p.handleEvent(event, sync); err != nil {
p.Logger.WithField("request-id", event.RequestId).Errorf("failed to handle event: %s", err)
}
p.Logger.Debugf("handled event %d/%d", i, eventsCount)
}
p.Logger.Debugf("finished handling events")
//Don't update the timestamp in DB, as a "real" LAPI might be running
//Worst case, crowdsec will receive a few duplicated events and will discard them
@ -223,16 +226,19 @@ func (p *Papi) Pull() error {
p.Logger.Infof("Starting Polling API Pull")
lastTimestamp := time.Time{}
lastTimestampStr, err := p.DBClient.GetConfigItem(PapiPullKey)
if err != nil {
p.Logger.Warningf("failed to get last timestamp for papi pull: %s", err)
}
//value doesn't exist, it's first time we're pulling
if lastTimestampStr == nil {
binTime, err := lastTimestamp.MarshalText()
if err != nil {
return fmt.Errorf("failed to marshal last timestamp: %w", err)
}
if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
p.Logger.Errorf("error setting papi pull last key: %s", err)
} else {
@ -245,10 +251,12 @@ func (p *Papi) Pull() error {
}
p.Logger.Infof("Starting PAPI pull (since:%s)", lastTimestamp)
for event := range p.Client.Start(lastTimestamp) {
logger := p.Logger.WithField("request-id", event.RequestId)
//update last timestamp in database
newTime := time.Now().UTC()
binTime, err := newTime.MarshalText()
if err != nil {
return fmt.Errorf("failed to marshal last timestamp: %w", err)
@ -262,11 +270,11 @@ func (p *Papi) Pull() error {
if err := p.DBClient.SetConfigItem(PapiPullKey, string(binTime)); err != nil {
return fmt.Errorf("failed to update last timestamp: %w", err)
} else {
}
logger.Debugf("set last timestamp to %s", newTime)
}
}
return nil
}
@ -274,6 +282,7 @@ func (p *Papi) SyncDecisions() error {
defer trace.CatchPanic("lapi/syncDecisionsToCAPI")
var cache models.DecisionsDeleteRequest
ticker := time.NewTicker(p.SyncInterval)
p.Logger.Infof("Start decisions sync to CrowdSec Central API (interval: %s)", p.SyncInterval)
@ -281,10 +290,13 @@ func (p *Papi) SyncDecisions() error {
select {
case <-p.syncTomb.Dying(): // if one apic routine is dying, do we kill the others?
p.Logger.Infof("sync decisions tomb is dying, sending cache (%d elements) before exiting", len(cache))
if len(cache) == 0 {
return nil
}
go p.SendDeletedDecisions(&cache)
return nil
case <-ticker.C:
if len(cache) > 0 {
@ -293,15 +305,19 @@ func (p *Papi) SyncDecisions() error {
cache = make([]models.DecisionsDeleteRequestItem, 0)
p.mu.Unlock()
p.Logger.Infof("sync decisions: %d deleted decisions to push", len(cacheCopy))
go p.SendDeletedDecisions(&cacheCopy)
}
case deletedDecisions := <-p.Channels.DeleteDecisionChannel:
if (p.consoleConfig.ShareManualDecisions != nil && *p.consoleConfig.ShareManualDecisions) || (p.consoleConfig.ConsoleManagement != nil && *p.consoleConfig.ConsoleManagement) {
var tmpDecisions []models.DecisionsDeleteRequestItem
p.Logger.Debugf("%d decisions deletion to add in cache", len(deletedDecisions))
for _, decision := range deletedDecisions {
tmpDecisions = append(tmpDecisions, models.DecisionsDeleteRequestItem(decision.UUID))
}
p.mu.Lock()
cache = append(cache, tmpDecisions...)
p.mu.Unlock()
@ -311,33 +327,42 @@ func (p *Papi) SyncDecisions() error {
}
func (p *Papi) SendDeletedDecisions(cacheOrig *models.DecisionsDeleteRequest) {
var cache []models.DecisionsDeleteRequestItem = *cacheOrig
var send models.DecisionsDeleteRequest
var (
cache []models.DecisionsDeleteRequestItem = *cacheOrig
send models.DecisionsDeleteRequest
)
bulkSize := 50
pageStart := 0
pageEnd := bulkSize
for {
if pageEnd >= len(cache) {
send = cache[pageStart:]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _, err := p.apiClient.DecisionDelete.Add(ctx, &send)
if err != nil {
p.Logger.Errorf("sending deleted decisions to central API: %s", err)
return
}
break
}
send = cache[pageStart:pageEnd]
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _, err := p.apiClient.DecisionDelete.Add(ctx, &send)
if err != nil {
//we log it here as well, because the return value of func might be discarded
p.Logger.Errorf("sending deleted decisions to central API: %s", err)
}
pageStart += bulkSize
pageEnd += bulkSize
}

View file

@ -40,17 +40,18 @@ type forcePull struct {
func DecisionCmd(message *Message, p *Papi, sync bool) error {
switch message.Header.OperationCmd {
case "delete":
data, err := json.Marshal(message.Data)
if err != nil {
return err
}
UUIDs := make([]string, 0)
deleteDecisionMsg := deleteDecisions{
Decisions: make([]string, 0),
}
if err := json.Unmarshal(data, &deleteDecisionMsg); err != nil {
return fmt.Errorf("message for '%s' contains bad data format: %s", message.Header.OperationType, err)
return fmt.Errorf("message for '%s' contains bad data format: %w", message.Header.OperationType, err)
}
UUIDs = append(UUIDs, deleteDecisionMsg.Decisions...)
@ -59,10 +60,13 @@ func DecisionCmd(message *Message, p *Papi, sync bool) error {
filter := make(map[string][]string)
filter["uuid"] = UUIDs
_, deletedDecisions, err := p.DBClient.SoftDeleteDecisionsWithFilter(filter)
if err != nil {
return fmt.Errorf("unable to delete decisions %+v : %s", UUIDs, err)
return fmt.Errorf("unable to delete decisions %+v: %w", UUIDs, err)
}
decisions := make([]*models.Decision, 0)
for _, deletedDecision := range deletedDecisions {
log.Infof("Decision from '%s' for '%s' (%s) has been deleted", deletedDecision.Origin, deletedDecision.Value, deletedDecision.Type)
dec := &models.Decision{
@ -92,6 +96,7 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
if err != nil {
return err
}
alert := &models.Alert{}
if err := json.Unmarshal(data, alert); err != nil {
@ -105,10 +110,12 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
log.Warnf("Alert %d has no StartAt, setting it to now", alert.ID)
alert.StartAt = ptr.Of(time.Now().UTC().Format(time.RFC3339))
}
if alert.StopAt == nil || *alert.StopAt == "" {
log.Warnf("Alert %d has no StopAt, setting it to now", alert.ID)
alert.StopAt = ptr.Of(time.Now().UTC().Format(time.RFC3339))
}
alert.EventsCount = ptr.Of(int32(0))
alert.Capacity = ptr.Of(int32(0))
alert.Leakspeed = ptr.Of("")
@ -128,12 +135,14 @@ func AlertCmd(message *Message, p *Papi, sync bool) error {
alert.Source.Scope = ptr.Of(types.ConsoleOrigin)
alert.Source.Value = &message.Header.Source.User
}
alert.Scenario = &message.Header.Message
for _, decision := range alert.Decisions {
if *decision.Scenario == "" {
decision.Scenario = &message.Header.Message
}
log.Infof("Adding decision for '%s' with UUID: %s", *decision.Value, decision.UUID)
}
@ -157,6 +166,7 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
log.Infof("Ignoring management command from PAPI in sync mode")
return nil
}
switch message.Header.OperationCmd {
case "reauth":
log.Infof("Received reauth command from PAPI, resetting token")
@ -187,12 +197,12 @@ func ManagementCmd(message *Message, p *Papi, sync bool) error {
Duration: &forcePullMsg.Blocklist.Duration,
}, true)
if err != nil {
return fmt.Errorf("failed to force pull operation: %s", err)
return fmt.Errorf("failed to force pull operation: %w", err)
}
}
default:
return fmt.Errorf("unknown command '%s' for operation type '%s'", message.Header.OperationCmd, message.Header.OperationType)
}
return nil
}