apiclient: fix http roundtrip (clone body also) (#1758)

* apiclient: fix http roundtrip (clone body also)
This commit is contained in:
he2ss 2022-12-14 16:42:46 +01:00 committed by GitHub
parent fe23da6e0c
commit 579cecde04
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 44 additions and 17 deletions

View file

@ -78,7 +78,7 @@ func (t *APIKeyTransport) transport() http.RoundTripper {
type JWTTransport struct {
MachineID *string
Password *strfmt.Password
token string
Token string
Expiration time.Time
Scenarios []string
URL *url.URL
@ -88,6 +88,7 @@ type JWTTransport struct {
// It will default to http.DefaultTransport if nil.
Transport http.RoundTripper
UpdateScenario func() ([]string, error)
NbRetry int
}
func (t *JWTTransport) refreshJwtToken() error {
@ -161,45 +162,63 @@ func (t *JWTTransport) refreshJwtToken() error {
if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
return errors.Wrap(err, "unable to parse jwt expiration")
}
t.token = response.Token
t.Token = response.Token
log.Debugf("token %s will expire on %s", t.token, t.Expiration.String())
log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
return nil
}
// RoundTrip implements the RoundTripper interface.
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
if t.NbRetry > 1 {
t.NbRetry = 0
return nil, fmt.Errorf("unable to refresh JWT token multiple times")
}
if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
if err := t.refreshJwtToken(); err != nil {
return nil, err
}
}
// We must make a copy of the Request so
// that we don't modify the Request we were given. This is required by the
// specification of http.RoundTripper.
req = cloneRequest(req)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.token))
log.Debugf("req-jwt: %s %s", req.Method, req.URL.String())
if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpRequest(req, true)
log.Tracef("req-jwt: %s", string(dump))
}
if t.UserAgent != "" {
req.Header.Add("User-Agent", t.UserAgent)
}
// We must make a copy of the Request so
// that we don't modify the Request we were given. This is required by the
// specification of http.RoundTripper.
clonedReq := cloneRequest(req)
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
if log.GetLevel() >= log.TraceLevel {
//requestToDump := cloneRequest(req)
dump, _ := httputil.DumpRequest(req, true)
log.Tracef("req-jwt: %s", string(dump))
}
// Make the HTTP request.
resp, err := t.transport().RoundTrip(req)
if log.GetLevel() >= log.TraceLevel {
dump, _ := httputil.DumpResponse(resp, true)
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
}
if err != nil || resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized {
if err != nil {
/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
t.token = ""
t.Token = ""
return resp, errors.Wrapf(err, "performing jwt auth")
}
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
t.Token = ""
t.NbRetry++
return t.RoundTrip(clonedReq)
}
t.NbRetry = 0
log.Debugf("resp-jwt: %d", resp.StatusCode)
return resp, nil
}
@ -225,5 +244,12 @@ func cloneRequest(r *http.Request) *http.Request {
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
if r.Body != nil {
var b bytes.Buffer
b.ReadFrom(r.Body)
r.Body = io.NopCloser(&b)
r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
}
return r2
}

View file

@ -234,5 +234,5 @@ func TestWatcherEnroll(t *testing.T) {
}
_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
assert.Contains(t, err.Error(), "the attachment key provided is not valid")
assert.Contains(t, err.Error(), "unable to refresh JWT token multiple times", "got %s", err.Error())
}

View file

@ -51,6 +51,7 @@ func NewClient(config *Config) (*ApiClient, error) {
UserAgent: config.UserAgent,
VersionPrefix: config.VersionPrefix,
UpdateScenario: config.UpdateScenario,
NbRetry: 0,
}
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
if Cert != nil {