diff --git a/cmd/crowdsec-cli/capi.go b/cmd/crowdsec-cli/capi.go index 11d4000d3..621305248 100644 --- a/cmd/crowdsec-cli/capi.go +++ b/cmd/crowdsec-cli/capi.go @@ -172,7 +172,7 @@ func NewCapiStatusCmd() *cobra.Command { } log.Infof("Loaded credentials from %s", csConfig.API.Server.OnlineClient.CredentialsFilePath) log.Infof("Trying to authenticate with username %s on %s", csConfig.API.Server.OnlineClient.Credentials.Login, apiurl) - _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) + _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) if err != nil { log.Fatalf("Failed to authenticate to Central API (CAPI) : %s", err) } diff --git a/cmd/crowdsec-cli/lapi.go b/cmd/crowdsec-cli/lapi.go index 97ecfed88..ae6d3f33f 100644 --- a/cmd/crowdsec-cli/lapi.go +++ b/cmd/crowdsec-cli/lapi.go @@ -63,7 +63,7 @@ func runLapiStatus(cmd *cobra.Command, args []string) error { } log.Infof("Loaded credentials from %s", csConfig.API.Client.CredentialsFilePath) log.Infof("Trying to authenticate with username %s on %s", login, apiurl) - _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) + _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) if err != nil { log.Fatalf("Failed to authenticate to Local API (LAPI) : %s", err) } else { diff --git a/cmd/crowdsec-cli/support.go b/cmd/crowdsec-cli/support.go index 2db31f781..e7110ae70 100644 --- a/cmd/crowdsec-cli/support.go +++ b/cmd/crowdsec-cli/support.go @@ -102,7 +102,6 @@ func collectFeatures() []byte { return w.Bytes() } - func collectOSInfo() ([]byte, error) { log.Info("Collecting OS info") info, err := osinfo.GetOSInfo() @@ -194,7 +193,7 @@ func collectAPIStatus(login string, password string, endpoint string, prefix str Scenarios: scenarios, } - _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) + _, _, err = Client.Auth.AuthenticateWatcher(context.Background(), t) if err != nil { return []byte(fmt.Sprintf("Could not authenticate to API: %s", err)) } else { @@ -277,7 +276,7 @@ cscli support dump -f /tmp/crowdsec-support.zip var err error var skipHub, skipDB, skipCAPI, skipLAPI, skipAgent bool infos := map[string][]byte{ - SUPPORT_VERSION_PATH: collectVersion(), + SUPPORT_VERSION_PATH: collectVersion(), SUPPORT_FEATURES_PATH: collectFeatures(), } diff --git a/cmd/crowdsec/output.go b/cmd/crowdsec/output.go index 5677297d6..efeab7720 100644 --- a/cmd/crowdsec/output.go +++ b/cmd/crowdsec/output.go @@ -97,13 +97,21 @@ func runOutput(input chan types.Event, overflow chan types.Event, buckets *leaky if err != nil { return errors.Wrapf(err, "new client api: %s", err) } - if _, err = Client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + authResp, _, err := Client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ MachineID: &apiConfig.Login, Password: &password, Scenarios: scenarios, - }); err != nil { + }) + if err != nil { return errors.Wrapf(err, "authenticate watcher (%s)", apiConfig.Login) } + + if err := Client.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { + return errors.Wrap(err, "unable to parse jwt expiration") + } + + Client.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token + //start the heartbeat service log.Debugf("Starting HeartBeat service") Client.HeartBeat.StartHeartBeat(context.Background(), &outputsTomb) diff --git a/pkg/apiclient/auth.go b/pkg/apiclient/auth.go index 69f3d2bb0..48b971a06 100644 --- a/pkg/apiclient/auth.go +++ b/pkg/apiclient/auth.go @@ -232,8 +232,9 @@ func (t *JWTTransport) refreshJwtToken() error { func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) { // in a few occasions several goroutines will execute refreshJwtToken concurrently which is useless and will cause overload on CAPI // we use a mutex to avoid this + //We also bypass the refresh if we are requesting the login endpoint, as it does not require a token, and it leads to do 2 requests instead of one (refresh + actual login request) t.refreshTokenMutex.Lock() - if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) { + if req.URL.Path != "/"+t.VersionPrefix+"/watchers/login" && (t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC())) { if err := t.refreshJwtToken(); err != nil { t.refreshTokenMutex.Unlock() return nil, err diff --git a/pkg/apiclient/auth_service.go b/pkg/apiclient/auth_service.go index bf02738d2..26ad80c0c 100644 --- a/pkg/apiclient/auth_service.go +++ b/pkg/apiclient/auth_service.go @@ -51,18 +51,20 @@ func (s *AuthService) RegisterWatcher(ctx context.Context, registration models.W return resp, nil } -func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.WatcherAuthRequest) (*Response, error) { +func (s *AuthService) AuthenticateWatcher(ctx context.Context, auth models.WatcherAuthRequest) (models.WatcherAuthResponse, *Response, error) { + var authResp models.WatcherAuthResponse + u := fmt.Sprintf("%s/watchers/login", s.client.URLPrefix) req, err := s.client.NewRequest(http.MethodPost, u, &auth) if err != nil { - return nil, err + return authResp, nil, err } - resp, err := s.client.Do(ctx, req, nil) + resp, err := s.client.Do(ctx, req, &authResp) if err != nil { - return resp, err + return authResp, resp, err } - return resp, nil + return authResp, resp, nil } func (s *AuthService) EnrollWatcher(ctx context.Context, enrollKey string, name string, tags []string, overwrite bool) (*Response, error) { diff --git a/pkg/apiclient/auth_service_test.go b/pkg/apiclient/auth_service_test.go index 515b684cd..c844afe8e 100644 --- a/pkg/apiclient/auth_service_test.go +++ b/pkg/apiclient/auth_service_test.go @@ -60,7 +60,7 @@ func TestWatcherAuth(t *testing.T) { t.Fatalf("new api client: %s", err) } - _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + _, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ MachineID: &mycfg.MachineID, Password: &mycfg.Password, Scenarios: mycfg.Scenarios, @@ -84,7 +84,7 @@ func TestWatcherAuth(t *testing.T) { t.Fatalf("new api client: %s", err) } - _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + _, _, err = client.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ MachineID: &mycfg.MachineID, Password: &mycfg.Password, }) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index d6983b121..d4dd4f86a 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -201,18 +201,25 @@ func NewAPIC(config *csconfig.OnlineApiClientCfg, dbClient *database.Client, con // The watcher will be authenticated by the RoundTripper the first time it will call CAPI // Explicit authentication will provoke an useless supplementary call to CAPI - // scenarios, err := ret.FetchScenariosListFromDB() - // if err != nil { - // return ret, errors.Wrapf(err, "get scenario in db: %s", err) - // } + scenarios, err := ret.FetchScenariosListFromDB() + if err != nil { + return ret, errors.Wrapf(err, "get scenario in db: %s", err) + } - // if _, err = ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ - // MachineID: &config.Credentials.Login, - // Password: &password, - // Scenarios: scenarios, - // }); err != nil { - // return ret, errors.Wrapf(err, "authenticate watcher (%s)", config.Credentials.Login) - // } + authResp, _, err := ret.apiClient.Auth.AuthenticateWatcher(context.Background(), models.WatcherAuthRequest{ + MachineID: &config.Credentials.Login, + Password: &password, + Scenarios: scenarios, + }) + if err != nil { + return ret, errors.Wrapf(err, "authenticate watcher (%s)", config.Credentials.Login) + } + + if err := ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration.UnmarshalText([]byte(authResp.Expire)); err != nil { + return ret, errors.Wrap(err, "unable to parse jwt expiration") + } + + ret.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Token = authResp.Token return ret, err }