apiclient: fix http roundtrip (clone body also) (#1758)
* apiclient: fix http roundtrip (clone body also)
This commit is contained in:
parent
fe23da6e0c
commit
579cecde04
|
@ -78,7 +78,7 @@ func (t *APIKeyTransport) transport() http.RoundTripper {
|
||||||
type JWTTransport struct {
|
type JWTTransport struct {
|
||||||
MachineID *string
|
MachineID *string
|
||||||
Password *strfmt.Password
|
Password *strfmt.Password
|
||||||
token string
|
Token string
|
||||||
Expiration time.Time
|
Expiration time.Time
|
||||||
Scenarios []string
|
Scenarios []string
|
||||||
URL *url.URL
|
URL *url.URL
|
||||||
|
@ -88,6 +88,7 @@ type JWTTransport struct {
|
||||||
// It will default to http.DefaultTransport if nil.
|
// It will default to http.DefaultTransport if nil.
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
UpdateScenario func() ([]string, error)
|
UpdateScenario func() ([]string, error)
|
||||||
|
NbRetry int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *JWTTransport) refreshJwtToken() error {
|
func (t *JWTTransport) refreshJwtToken() error {
|
||||||
|
@ -161,45 +162,63 @@ func (t *JWTTransport) refreshJwtToken() error {
|
||||||
if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
|
if err := t.Expiration.UnmarshalText([]byte(response.Expire)); err != nil {
|
||||||
return errors.Wrap(err, "unable to parse jwt expiration")
|
return errors.Wrap(err, "unable to parse jwt expiration")
|
||||||
}
|
}
|
||||||
t.token = response.Token
|
t.Token = response.Token
|
||||||
|
|
||||||
log.Debugf("token %s will expire on %s", t.token, t.Expiration.String())
|
log.Debugf("token %s will expire on %s", t.Token, t.Expiration.String())
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip implements the RoundTripper interface.
|
// RoundTrip implements the RoundTripper interface.
|
||||||
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (t *JWTTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
if t.token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
|
if t.NbRetry > 1 {
|
||||||
|
t.NbRetry = 0
|
||||||
|
return nil, fmt.Errorf("unable to refresh JWT token multiple times")
|
||||||
|
}
|
||||||
|
if t.Token == "" || t.Expiration.Add(-time.Minute).Before(time.Now().UTC()) {
|
||||||
if err := t.refreshJwtToken(); err != nil {
|
if err := t.refreshJwtToken(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We must make a copy of the Request so
|
|
||||||
// that we don't modify the Request we were given. This is required by the
|
|
||||||
// specification of http.RoundTripper.
|
|
||||||
req = cloneRequest(req)
|
|
||||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.token))
|
|
||||||
log.Debugf("req-jwt: %s %s", req.Method, req.URL.String())
|
|
||||||
if log.GetLevel() >= log.TraceLevel {
|
|
||||||
dump, _ := httputil.DumpRequest(req, true)
|
|
||||||
log.Tracef("req-jwt: %s", string(dump))
|
|
||||||
}
|
|
||||||
if t.UserAgent != "" {
|
if t.UserAgent != "" {
|
||||||
req.Header.Add("User-Agent", t.UserAgent)
|
req.Header.Add("User-Agent", t.UserAgent)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We must make a copy of the Request so
|
||||||
|
// that we don't modify the Request we were given. This is required by the
|
||||||
|
// specification of http.RoundTripper.
|
||||||
|
clonedReq := cloneRequest(req)
|
||||||
|
|
||||||
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", t.Token))
|
||||||
|
|
||||||
|
if log.GetLevel() >= log.TraceLevel {
|
||||||
|
//requestToDump := cloneRequest(req)
|
||||||
|
dump, _ := httputil.DumpRequest(req, true)
|
||||||
|
log.Tracef("req-jwt: %s", string(dump))
|
||||||
|
}
|
||||||
|
|
||||||
// Make the HTTP request.
|
// Make the HTTP request.
|
||||||
resp, err := t.transport().RoundTrip(req)
|
resp, err := t.transport().RoundTrip(req)
|
||||||
if log.GetLevel() >= log.TraceLevel {
|
if log.GetLevel() >= log.TraceLevel {
|
||||||
dump, _ := httputil.DumpResponse(resp, true)
|
dump, _ := httputil.DumpResponse(resp, true)
|
||||||
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
|
log.Tracef("resp-jwt: %s (err:%v)", string(dump), err)
|
||||||
}
|
}
|
||||||
if err != nil || resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized {
|
if err != nil {
|
||||||
/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
|
/*we had an error (network error for example, or 401 because token is refused), reset the token ?*/
|
||||||
t.token = ""
|
t.Token = ""
|
||||||
return resp, errors.Wrapf(err, "performing jwt auth")
|
return resp, errors.Wrapf(err, "performing jwt auth")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||||
|
t.Token = ""
|
||||||
|
t.NbRetry++
|
||||||
|
return t.RoundTrip(clonedReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.NbRetry = 0
|
||||||
|
|
||||||
log.Debugf("resp-jwt: %d", resp.StatusCode)
|
log.Debugf("resp-jwt: %d", resp.StatusCode)
|
||||||
|
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -225,5 +244,12 @@ func cloneRequest(r *http.Request) *http.Request {
|
||||||
for k, s := range r.Header {
|
for k, s := range r.Header {
|
||||||
r2.Header[k] = append([]string(nil), s...)
|
r2.Header[k] = append([]string(nil), s...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.Body != nil {
|
||||||
|
var b bytes.Buffer
|
||||||
|
b.ReadFrom(r.Body)
|
||||||
|
r.Body = io.NopCloser(&b)
|
||||||
|
r2.Body = io.NopCloser(bytes.NewReader(b.Bytes()))
|
||||||
|
}
|
||||||
return r2
|
return r2
|
||||||
}
|
}
|
||||||
|
|
|
@ -234,5 +234,5 @@ func TestWatcherEnroll(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
|
_, err = client.Auth.EnrollWatcher(context.Background(), "badkey", "", []string{}, false)
|
||||||
assert.Contains(t, err.Error(), "the attachment key provided is not valid")
|
assert.Contains(t, err.Error(), "unable to refresh JWT token multiple times", "got %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,6 +51,7 @@ func NewClient(config *Config) (*ApiClient, error) {
|
||||||
UserAgent: config.UserAgent,
|
UserAgent: config.UserAgent,
|
||||||
VersionPrefix: config.VersionPrefix,
|
VersionPrefix: config.VersionPrefix,
|
||||||
UpdateScenario: config.UpdateScenario,
|
UpdateScenario: config.UpdateScenario,
|
||||||
|
NbRetry: 0,
|
||||||
}
|
}
|
||||||
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
|
tlsconfig := tls.Config{InsecureSkipVerify: InsecureSkipVerify}
|
||||||
if Cert != nil {
|
if Cert != nil {
|
||||||
|
|
Loading…
Reference in a new issue