diff --git a/connection_test.go b/connection_test.go index 8ad3b75ec..638c3189f 100644 --- a/connection_test.go +++ b/connection_test.go @@ -167,27 +167,27 @@ var _ = Describe("Connection", func() { Expect(err).ToNot(HaveOccurred()) defer connection.Close() _, _, err = connection.Tokens() - Expect(transport.called).To(BeTrue()) + Expect(transport.called > 2).To(BeTrue()) // it means the retry was called Expect(err).To(HaveOccurred()) }) }) type TestTransport struct { - called bool + called int } func (t *TestTransport) RoundTrip(request *http.Request) (response *http.Response, err error) { - t.called = true + t.called++ header := http.Header{} header.Add("Content-type", "application/json") response = &http.Response{ - StatusCode: 401, + StatusCode: http.StatusInternalServerError, Header: header, - Body: gbytes.NewBuffer(), + Body: gbytes.BufferWithBytes([]byte("{}")), } return response, nil } func NewTestTransport() *TestTransport { - return &TestTransport{called: false} + return &TestTransport{called: 0} } diff --git a/go.mod b/go.mod index ac4482a2b..afee86afd 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/openshift-online/ocm-sdk-go go 1.12 require ( + github.com/cenkalti/backoff/v4 v4.0.0 github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/ghodss/yaml v1.0.0 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b diff --git a/go.sum b/go.sum index f4fa48928..7dd996bed 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/cenkalti/backoff/v4 v4.0.0 h1:6VeaLF9aI+MAUQ95106HwWzYZgJJpZ4stumjj6RFYAU= +github.com/cenkalti/backoff/v4 v4.0.0/go.mod h1:eEew/i+1Q6OrCDZh3WiXYv3+nJwBASZ8Bog/87DQnVg= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/token.go b/token.go index ae2953987..d599212bd 100644 --- a/token.go +++ b/token.go @@ -31,8 +31,8 @@ import ( "strings" "time" + "github.com/cenkalti/backoff/v4" "github.com/dgrijalva/jwt-go" - "github.com/openshift-online/ocm-sdk-go/internal" ) @@ -49,7 +49,30 @@ func (c *Connection) Tokens() (access, refresh string, err error) { // TokensContext returns the access and refresh tokens that is currently in use by the connection. // If it is necessary to request a new token because it wasn't requested yet, or because it is // expired, this method will do it and will return an error if it fails. +// The function will retry the operation in an exponential-backoff method. func (c *Connection) TokensContext(ctx context.Context) (access, refresh string, err error) { + operation := func() error { + c.logger.Debug(ctx, "trying to get tokens") + var code int + code, access, refresh, err = c.tokensContext(ctx) + if err != nil { + if code >= http.StatusInternalServerError { + c.logger.Error(ctx, "failed to get tokens, got http code %d, will attempt to retry. err: %v", code, err) + return err + } + c.logger.Debug(ctx, "failed to get tokens, got http code %d, will not attempt to retry. err: %v", code, err) + return nil + } + return nil + } + backoffMethod := backoff.NewExponentialBackOff() + backoffMethod.MaxElapsedTime = time.Second * 15 + // nolint + backoff.Retry(operation, backoffMethod) + return access, refresh, err +} + +func (c *Connection) tokensContext(ctx context.Context) (code int, access, refresh string, err error) { // We need to make sure that this method isn't execute concurrently, as we will be updating // multiple attributes of the connection: c.tokenMutex.Lock() @@ -88,7 +111,7 @@ func (c *Connection) TokensContext(ctx context.Context) (access, refresh string, // At this point we know that the access token is unavailable, expired or about to expire. // So we need to check if we can use the refresh token to request a new one. if c.refreshToken != nil && (!refreshExpires || refreshLeft >= 1*time.Minute) { - _, _, err = c.sendRefreshTokenForm(ctx) + code, _, err = c.sendRefreshTokenForm(ctx) if err != nil { return } @@ -100,7 +123,7 @@ func (c *Connection) TokensContext(ctx context.Context) (access, refresh string, // expire. So we need to check if we have other credentials that can be used to request a // new token, and use them. if c.haveCredentials() { - _, _, err = c.sendRequestTokenForm(ctx) + code, _, err = c.sendRequestTokenForm(ctx) if err != nil { return } @@ -118,7 +141,7 @@ func (c *Connection) TokensContext(ctx context.Context) (access, refresh string, "obtain a new token, so will try to use it anyhow", refreshLeft, ) - _, _, err = c.sendRefreshTokenForm(ctx) + code, _, err = c.sendRefreshTokenForm(ctx) if err != nil { return } @@ -282,6 +305,8 @@ func (c *Connection) sendTokenFormTimed(ctx context.Context, form url.Values) (c } defer response.Body.Close() + code = response.StatusCode + // Check that the response content type is JSON: err = c.checkContentType(response) if err != nil { @@ -311,7 +336,7 @@ func (c *Connection) sendTokenFormTimed(ctx context.Context, form url.Values) (c return } if response.StatusCode != http.StatusOK { - err = fmt.Errorf("token response status is: %s", response.Status) + err = fmt.Errorf("token response status code is '%d'", response.StatusCode) return } if result.TokenType != nil && *result.TokenType != "bearer" { diff --git a/token_test.go b/token_test.go index a8227a27c..699aa4c61 100644 --- a/token_test.go +++ b/token_test.go @@ -247,17 +247,19 @@ var _ = Describe("Tokens", func() { refreshToken := DefaultToken("Refresh", 10*time.Hour) // Configure the server: - oidServer.AppendHandlers( - ghttp.RespondWith( - http.StatusServiceUnavailable, - `Service unavailable`, - http.Header{ - "Content-Type": []string{ - "text/plain", + for i := 0; i < 100; i++ { // there are going to be several retries + oidServer.AppendHandlers( + ghttp.RespondWith( + http.StatusServiceUnavailable, + `Service unavailable`, + http.Header{ + "Content-Type": []string{ + "text/plain", + }, }, - }, - ), - ) + ), + ) + } // Create the connection: connection, err := NewConnectionBuilder(). @@ -287,7 +289,7 @@ var _ = Describe("Tokens", func() { // Configure the server: oidServer.AppendHandlers( ghttp.RespondWith( - http.StatusServiceUnavailable, + http.StatusBadRequest, content, http.Header{ "Content-Type": []string{ @@ -913,6 +915,99 @@ var _ = Describe("Tokens", func() { Expect(err).ToNot(HaveOccurred()) }) }) + + Describe("Test retry for getting access token", func() { + It("Return access token after a few retries", func() { + // Generate tokens: + refreshToken := DefaultToken("Refresh", 10*time.Hour) + accessToken := DefaultToken("Bearer", 5*time.Minute) + + oidServer.AppendHandlers( + ghttp.RespondWith( + http.StatusInternalServerError, + `Internal Server Error`, + http.Header{ + "Content-Type": []string{ + "text/plain", + }, + }, + ), + ghttp.RespondWith( + http.StatusBadGateway, + `Bad Gateway`, + http.Header{ + "Content-Type": []string{ + "text/plain", + }, + }, + ), + ghttp.CombineHandlers( + VerifyRefreshGrant(refreshToken), + RespondWithTokens(accessToken, refreshToken), + ), + ) + + // Create the connection: + connection, err := NewConnectionBuilder(). + Logger(logger). + TokenURL(oidServer.URL()). + URL(apiServer.URL()). + Tokens(refreshToken). + Build() + Expect(err).ToNot(HaveOccurred()) + defer connection.Close() + + // Get the tokens: + returnedAccess, returnedRefresh, err := connection.Tokens() + Expect(err).ToNot(HaveOccurred()) + Expect(returnedAccess).ToNot(BeEmpty()) + Expect(returnedRefresh).ToNot(BeEmpty()) + }) + It("Test no retry when status is not http 5xx", func() { + // Generate tokens: + refreshToken := DefaultToken("Refresh", 10*time.Hour) + accessToken := DefaultToken("Bearer", 5*time.Minute) + + oidServer.AppendHandlers( + ghttp.RespondWith( + http.StatusInternalServerError, + `Internal Server Error`, + http.Header{ + "Content-Type": []string{ + "text/plain", + }, + }, + ), + ghttp.RespondWith( + http.StatusForbidden, + `{}`, + http.Header{ + "Content-Type": []string{ + "application/json", + }, + }, + ), + ghttp.CombineHandlers( + VerifyRefreshGrant(refreshToken), + RespondWithTokens(accessToken, refreshToken), + ), + ) + + // Create the connection: + connection, err := NewConnectionBuilder(). + Logger(logger). + TokenURL(oidServer.URL()). + URL(apiServer.URL()). + Tokens(refreshToken). + Build() + Expect(err).ToNot(HaveOccurred()) + defer connection.Close() + + // Get the tokens: + _, _, err = connection.Tokens() + Expect(err).To(HaveOccurred()) + }) + }) }) func VerifyPasswordGrant(user, password string) http.HandlerFunc {