diff --git a/client/client.go b/client/client.go index 175a57351..20bd90b4b 100644 --- a/client/client.go +++ b/client/client.go @@ -62,9 +62,15 @@ func (c *DatabricksClient) GetOAuthToken(ctx context.Context, authDetails string } // Do sends an HTTP request against path. -func (c *DatabricksClient) Do(ctx context.Context, method, path string, - headers map[string]string, request, response any, - visitors ...func(*http.Request) error) error { +func (c *DatabricksClient) Do( + ctx context.Context, + method string, + path string, + headers map[string]string, + request any, + response any, + visitors ...func(*http.Request) error, +) error { opts := []httpclient.DoOption{} for _, v := range visitors { opts = append(opts, httpclient.WithRequestVisitor(v)) diff --git a/config/api_client.go b/config/api_client.go index 0913e35b3..6825d630f 100644 --- a/config/api_client.go +++ b/config/api_client.go @@ -6,9 +6,11 @@ import ( "fmt" "net/http" "net/url" + "regexp" "time" "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/common" "github.com/databricks/databricks-sdk-go/credentials" "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/useragent" @@ -73,17 +75,22 @@ func (c *Config) NewApiClient() (*httpclient.ApiClient, error) { return nil }, }, - TransientErrors: []string{ - "REQUEST_LIMIT_EXCEEDED", // This is temporary workaround for SCIM API returning 500. 
Remove when it's fixed - }, ErrorMapper: apierr.GetAPIError, - ErrorRetriable: func(ctx context.Context, err error) bool { - var apiErr *apierr.APIError - if errors.As(err, &apiErr) { - return apiErr.IsRetriable(ctx) - } - return false - }, + ErrorRetriable: httpclient.CombineRetriers( + func(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool { + var apiErr *apierr.APIError + if errors.As(err, &apiErr) { + return apiErr.IsRetriable(ctx) + } + return false + }, + httpclient.RetryUrlErrors, + httpclient.RetryTransientErrors([]string{"REQUEST_LIMIT_EXCEEDED"}), + httpclient.RetryMatchedRequests([]httpclient.RestApiMatcher{ + // Get Permissions API can be retried on 504 (the version dot is escaped so it only matches a literal "2.0") + {Method: http.MethodGet, Path: *regexp.MustCompile(`/api/2\.0/permissions/[^/]+/[^/]+`)}, + }, httpclient.RetryOnGatewayTimeout), + ), }), nil } diff --git a/config/api_client_test.go b/config/api_client_test.go new file mode 100644 index 000000000..44b7648f3 --- /dev/null +++ b/config/api_client_test.go @@ -0,0 +1,54 @@ +package config + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" + + "github.com/databricks/databricks-sdk-go/httpclient" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type hc func(r *http.Request) (*http.Response, error) + +func (cb hc) RoundTrip(r *http.Request) (*http.Response, error) { + return cb(r) +} + +func (cb hc) SkipRetryOnIO() bool { + return true +} + +func TestApiClient_RetriesGetPermissionsOnGatewayTimeout(t *testing.T) { + requestCount := 0 + c := &Config{ + HTTPTransport: hc(func(r *http.Request) (*http.Response, error) { + initialRequestCount := requestCount + requestCount++ + if initialRequestCount == 0 { + return &http.Response{ + Request: r, + StatusCode: http.StatusGatewayTimeout, + Body: io.NopCloser(strings.NewReader( + fmt.Sprintf(`{"error_code":"TEMPORARILY_UNAVAILABLE", "message":"The service at %s is taking too long to process your request. 
Please try again later or try a faster operation."}`, r.URL))), + }, nil + } + return &http.Response{ + Request: r, + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"permissions": ["can_run_queries"]}`)), + }, nil + }), + } + client, err := c.NewApiClient() + require.NoError(t, err) + ctx := context.Background() + var res map[string][]string + err = client.Do(ctx, "GET", "/api/2.0/permissions/object/id", httpclient.WithResponseUnmarshal(&res)) + assert.NoError(t, err) + assert.Equal(t, map[string][]string{"permissions": {"can_run_queries"}}, res) +} diff --git a/config/config.go b/config/config.go index fcf69d2cb..f2c948844 100644 --- a/config/config.go +++ b/config/config.go @@ -311,13 +311,16 @@ func (c *Config) EnsureResolved() error { HTTPTimeout: time.Duration(c.HTTPTimeoutSeconds) * time.Second, Transport: c.HTTPTransport, ErrorMapper: c.refreshTokenErrorMapper, - TransientErrors: []string{ - "throttled", - "too many requests", - "429", - "request limit exceeded", - "rate limit", - }, + ErrorRetriable: httpclient.CombineRetriers( + httpclient.DefaultErrorRetriable, + httpclient.RetryTransientErrors([]string{ + "throttled", + "too many requests", + "429", + "request limit exceeded", + "rate limit", + }), + ), }) if c.azureTenantIdFetchClient == nil { c.azureTenantIdFetchClient = &http.Client{ diff --git a/httpclient/api_client.go b/httpclient/api_client.go index 2130fd3bb..b62ebdba4 100644 --- a/httpclient/api_client.go +++ b/httpclient/api_client.go @@ -9,7 +9,6 @@ import ( "net/http" "net/url" "runtime" - "strings" "time" "github.com/databricks/databricks-sdk-go/common" @@ -28,16 +27,25 @@ type ClientConfig struct { AuthVisitor RequestVisitor Visitors []RequestVisitor - RetryTimeout time.Duration + // The maximum amount of time to retry requests that return retriable errors. + // If unset, the default is 5 minutes. + RetryTimeout time.Duration + + // Returns the amount of time to wait after the given attempt. 
+ RetryBackoff retries.BackoffFunc + HTTPTimeout time.Duration InsecureSkipVerify bool DebugHeaders bool DebugTruncateBytes int RateLimitPerSecond int - ErrorMapper func(ctx context.Context, resp common.ResponseWrapper) error - ErrorRetriable func(ctx context.Context, err error) bool - TransientErrors []string + // ErrorMapper converts the API response into a Go error if the response is an error. + ErrorMapper func(ctx context.Context, resp common.ResponseWrapper) error + + // ErrorRetriable determines if the API request should be retried. It is not + // called if the context is cancelled or if the request succeeded. + ErrorRetriable ErrorRetrier Transport http.RoundTripper } @@ -130,7 +138,6 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio // merge client-wide and request-specific visitors visitors = append(visitors, o.in) } - } // Use default AuthVisitor if none is provided if authVisitor == nil { @@ -170,45 +177,6 @@ func (c *ApiClient) Do(ctx context.Context, method, path string, opts ...DoOptio return nil } -func (c *ApiClient) isRetriable(ctx context.Context, err error) bool { - if c.config.ErrorRetriable(ctx, err) { - return true - } - if isRetriableUrlError(err) { - // all IO errors are retriable - logger.Debugf(ctx, "Attempting retry because of IO error: %s", err) - return true - } - message := err.Error() - // Handle transient errors for retries - for _, substring := range c.config.TransientErrors { - if strings.Contains(message, substring) { - logger.Debugf(ctx, "Attempting retry because of %#v", substring) - return true - } - } - // some API's recommend retries on HTTP 500, but we'll add that later - return false -} - -// Common error-handling logic for all responses that may need to be retried. -// -// If the error is retriable, return a retries.Err to retry the request. However, as the request body will have been consumed -// by the first attempt, the body must be reset before retrying. 
If the body cannot be reset, return a retries.Err to halt. -// -// Always returns nil for the first parameter as there is no meaningful response body to return in the error case. -// -// If it is certain that an error should not be retried, use failRequest() instead. -func (c *ApiClient) handleError(ctx context.Context, err error, body common.RequestBody) (*common.ResponseWrapper, *retries.Err) { - if !c.isRetriable(ctx, err) { - return nil, retries.Halt(err) - } - if resetErr := body.Reset(); resetErr != nil { - return nil, retries.Halt(resetErr) - } - return nil, retries.Continue(err) -} - // Fails the request with a retries.Err to halt future retries. func (c *ApiClient) failRequest(msg string, err error) (*common.ResponseWrapper, *retries.Err) { err = fmt.Errorf("%s: %w", msg, err) @@ -299,7 +267,16 @@ func (c *ApiClient) attempt( // proactively release the connections in HTTP connection pool c.httpClient.CloseIdleConnections() - return c.handleError(ctx, err, requestBody) + + // Non-retriable errors can be returned immediately. + if !c.config.ErrorRetriable(ctx, request, &responseWrapper, err) { + return nil, retries.Halt(err) + } + // Retriable errors may require the request body to be reset. + if resetErr := requestBody.Reset(); resetErr != nil { + return nil, retries.Halt(resetErr) + } + return nil, retries.Continue(err) } } @@ -331,16 +308,24 @@ func (c *ApiClient) recordRequestLog( func (c *ApiClient) RoundTrip(request *http.Request) (*http.Response, error) { ctx := request.Context() requestURL := request.URL.String() - resp, err := retries.Poll(ctx, c.config.RetryTimeout, - c.attempt(ctx, request.Method, requestURL, common.RequestBody{ - Reader: request.Body, - // DO NOT DECODE BODY, because it may contain sensitive payload, - // like Azure Service Principal in a multipart/form-data body. 
- DebugBytes: []byte(""), - }, func(r *http.Request) error { - r.Header = request.Header - return nil - })) + retrier := makeRetrier[common.ResponseWrapper](c.config) + resp, err := retrier.Run( + ctx, + func(ctx context.Context) (*common.ResponseWrapper, error) { + resp, err := c.attempt(ctx, request.Method, requestURL, common.RequestBody{ + Reader: request.Body, + // DO NOT DECODE BODY, because it may contain sensitive payload, + // like Azure Service Principal in a multipart/form-data body. + DebugBytes: []byte(""), + }, func(r *http.Request) error { + r.Header = request.Header + return nil + })() + if err != nil { + return nil, err + } + return resp, nil + }) if err != nil { return nil, err } @@ -365,8 +350,16 @@ func (c *ApiClient) perform( requestBody common.RequestBody, visitors ...RequestVisitor, ) (*common.ResponseWrapper, error) { - resp, err := retries.Poll(ctx, c.config.RetryTimeout, - c.attempt(ctx, method, requestURL, requestBody, visitors...)) + retrier := makeRetrier[common.ResponseWrapper](c.config) + resp, err := retrier.Run( + ctx, + func(ctx context.Context) (*common.ResponseWrapper, error) { + resp, err := c.attempt(ctx, method, requestURL, requestBody, visitors...)() + if err != nil { + return resp, err + } + return resp, nil + }) var timedOut *retries.ErrTimedOut if errors.As(err, &timedOut) { // TODO: check if we want to unwrap this error here diff --git a/httpclient/errors.go b/httpclient/errors.go index 540c6b885..2d9ac6a73 100644 --- a/httpclient/errors.go +++ b/httpclient/errors.go @@ -7,9 +7,11 @@ import ( "io" "net/http" "net/url" + "regexp" "strings" "github.com/databricks/databricks-sdk-go/common" + "github.com/databricks/databricks-sdk-go/logger" ) type HttpError struct { @@ -45,17 +47,46 @@ func DefaultErrorMapper(ctx context.Context, resp common.ResponseWrapper) error } } -func DefaultErrorRetriable(ctx context.Context, err error) bool { - var httpError *HttpError - if errors.As(err, &httpError) { - if httpError.StatusCode == 
http.StatusTooManyRequests { - return true - } - if httpError.StatusCode == http.StatusGatewayTimeout { - return true +// ErrorRetrier determines whether a request should be retried. The request should be retried if +// and only if the function returns true. +type ErrorRetrier func(context.Context, *http.Request, *common.ResponseWrapper, error) bool + +// DefaultErrorRetriable is the ErrorRetrier used if none is specified. It retries on 429 and 504 responses and on transient URL errors. +func DefaultErrorRetriable(ctx context.Context, req *http.Request, resp *common.ResponseWrapper, err error) bool { + return CombineRetriers( + RetryOnTooManyRequests, + RetryOnGatewayTimeout, + RetryUrlErrors, + )(ctx, req, resp, err) +} + +// RetryOnTooManyRequests retries when the response status code is 429. It returns false when no response is available. +func RetryOnTooManyRequests(ctx context.Context, _ *http.Request, resp *common.ResponseWrapper, err error) bool { + if resp == nil || resp.Response == nil { + return false + } + return resp.Response.StatusCode == http.StatusTooManyRequests +} + +// RetryOnGatewayTimeout retries when the response status code is 504. It returns false when no response is available. +func RetryOnGatewayTimeout(ctx context.Context, _ *http.Request, resp *common.ResponseWrapper, err error) bool { + if resp == nil || resp.Response == nil { + return false + } + return resp.Response.StatusCode == http.StatusGatewayTimeout +} + +// CombineRetriers combines multiple ErrorRetriers into a single ErrorRetrier. The combined ErrorRetrier +// will return true if any of the input ErrorRetriers return true. 
+func CombineRetriers(retriers ...ErrorRetrier) ErrorRetrier { + return func(ctx context.Context, req *http.Request, resp *common.ResponseWrapper, err error) bool { + for _, retrier := range retriers { + if retrier(ctx, req, resp, err) { + return true + } } + return false } - return false } var urlErrorTransientErrorMessages = []string{ @@ -66,15 +97,58 @@ var urlErrorTransientErrorMessages = []string{ "i/o timeout", } -func isRetriableUrlError(err error) bool { +// RetryUrlErrors retries when the error is a *url.Error with a transient error message. +func RetryUrlErrors(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool { var urlError *url.Error if !errors.As(err, &urlError) { return false } for _, msg := range urlErrorTransientErrorMessages { if strings.Contains(err.Error(), msg) { + logger.Debugf(ctx, "Attempting retry because of IO error: %s", err) return true } } return false } + +// RetryTransientErrors retries when the error message contains any of the provided substrings. +func RetryTransientErrors(substrings []string) ErrorRetrier { + return func(ctx context.Context, _ *http.Request, _ *common.ResponseWrapper, err error) bool { + message := err.Error() + // Handle transient errors for retries + for _, substring := range substrings { + if strings.Contains(message, substring) { + logger.Debugf(ctx, "Attempting retry because of %#v", substring) + return true + } + } + return false + } +} + +// RestApiMatcher matches a request based on the HTTP method and path. +type RestApiMatcher struct { + // Method is the HTTP method to match. + Method string + // Path is the regular expression to match the path. + Path regexp.Regexp +} + +// Matches returns true if the request matches the method and path. +func (m *RestApiMatcher) Matches(req *http.Request) bool { + return req.Method == m.Method && m.Path.MatchString(req.URL.Path) +} + +// RetryMatchedRequests applies a retrier that only applies to requests matching one of the provided matchers. 
+func RetryMatchedRequests(matchers []RestApiMatcher, retryer ErrorRetrier) ErrorRetrier { + return func(ctx context.Context, r *http.Request, rw *common.ResponseWrapper, err error) bool { + for _, m := range matchers { + if m.Matches(r) && retryer(ctx, r, rw, err) { + logger.Debugf(ctx, "Attempting retry of matched request %s %s", r.Method, r.URL.Path) + return true + } + } + return false + } +} diff --git a/httpclient/errors_test.go b/httpclient/errors_test.go index a9227ced5..775011b4b 100644 --- a/httpclient/errors_test.go +++ b/httpclient/errors_test.go @@ -6,11 +6,134 @@ import ( "net/http" "strings" "testing" + "time" + "github.com/databricks/databricks-sdk-go/common" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type mock struct { + MaxFails int // number of times the failed Response is returned + FailResponse *http.Response // response to return in case of fail + FailError error // error to return in case of fail + NumCalls int // total number of calls +} + +func (m *mock) RoundTrip(r *http.Request) (*http.Response, error) { + m.NumCalls++ + if m.NumCalls <= m.MaxFails { + resp := *m.FailResponse + resp.Request = r + return &resp, m.FailError + } + return &http.Response{ + Request: r, + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{}`)), + }, nil +} + +func (m *mock) SkipRetryOnIO() bool { + return true +} + +func TestApiClient_Do_retries(t *testing.T) { + testCases := []struct { + desc string + mock *mock + errorRetrier ErrorRetrier + wantErrorMsg string + wantNumCalls int + }{ + { + desc: "default retrier retries on 429", + mock: &mock{ + MaxFails: 1, + FailResponse: &http.Response{StatusCode: http.StatusTooManyRequests}, + }, + wantNumCalls: 2, + }, + { + desc: "default retrier retries on 504", + mock: &mock{ + MaxFails: 1, + FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout}, + }, + wantNumCalls: 2, + }, + { + desc: "default retrier does not retry on 503", + mock: &mock{ + MaxFails: 1, + FailResponse: 
&http.Response{StatusCode: http.StatusServiceUnavailable}, + }, + wantErrorMsg: "http 503: ", + wantNumCalls: 1, + }, + { + desc: "no retry when ErrorRetriable returns false", + mock: &mock{ + MaxFails: 1, + FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout}, + }, + errorRetrier: func(context.Context, *http.Request, *common.ResponseWrapper, error) bool { + return false + }, + wantErrorMsg: "http 504: ", + wantNumCalls: 1, + }, + { + desc: "retry 1 time", + mock: &mock{ + MaxFails: 1, + FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout}, + }, + errorRetrier: func(context.Context, *http.Request, *common.ResponseWrapper, error) bool { + return true + }, + wantNumCalls: 2, + }, + { + desc: "retry multiple times", + mock: &mock{ + MaxFails: 3, + FailResponse: &http.Response{StatusCode: http.StatusGatewayTimeout}, + }, + errorRetrier: func(_ context.Context, _ *http.Request, _ *common.ResponseWrapper, _ error) bool { + return true + }, + wantNumCalls: 4, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + clientConfig := ClientConfig{ + Transport: tc.mock, + ErrorRetriable: tc.errorRetrier, + RetryBackoff: func(_ int) time.Duration { + // Do not wait to retry in tests + return 0 + }, + } + client := NewApiClient(clientConfig) + + err := client.Do(context.Background(), "GET", "test-path") + if tc.wantErrorMsg == "" { + assert.NoError(t, err) + } else { + assert.Contains(t, err.Error(), tc.wantErrorMsg) + } + gotNumCalls := tc.mock.NumCalls + + if gotNumCalls != tc.wantNumCalls { + t.Errorf("got %d calls, want %d", gotNumCalls, tc.wantNumCalls) + } + }) + } +} + func TestSimpleRequestAPIError(t *testing.T) { c := NewApiClient(ClientConfig{ Transport: hc(func(r *http.Request) (*http.Response, error) { diff --git a/httpclient/retries.go b/httpclient/retries.go new file mode 100644 index 000000000..0f150a3af --- /dev/null +++ b/httpclient/retries.go @@ -0,0 +1,13 @@ +package httpclient + +import ( + 
"github.com/databricks/databricks-sdk-go/retries" +) + +func makeRetrier[T any](c ClientConfig) retries.Retrier[T] { + return retries.New[T]( + retries.WithTimeout(c.RetryTimeout), + retries.WithRetryFunc(retries.DefaultShouldRetry), + retries.WithBackoffFunc(c.RetryBackoff), + ) +} diff --git a/retries/retries.go b/retries/retries.go index 8517e0e07..dfec6ac0a 100644 --- a/retries/retries.go +++ b/retries/retries.go @@ -67,6 +67,9 @@ func Continuef(format string, err error, args ...interface{}) *Err { return Continue(wrapped) } +// BackoffFunc is a function that returns the duration to wait before retrying the given attempt. +type BackoffFunc func(int) time.Duration + var maxWait = 10 * time.Second var minJitter = 50 * time.Millisecond var maxJitter = 750 * time.Millisecond @@ -155,6 +158,14 @@ func WithRetryFunc(halt func(error) bool) RetryOption { } } +// WithBackoffFunc configures the backoff duration for a given attempt. The retrier will wait +// for the returned duration before retrying. +func WithBackoffFunc(f func(attempt int) time.Duration) RetryOption { + return func(rc *RetryConfig) { + rc.backoff = f + } +} + // Retrier is a struct that can retry an operation until it succeeds or the timeout is reached. // The empty struct indicates that the retrier should run for 20 minutes and retry on any non-nil error. // The type parameter is the return type of the Run() method. When using the Wait() method, this can be struct{}. 
@@ -229,7 +240,7 @@ func (r Retrier[T]) Run(ctx context.Context, fn func(context.Context) (*T, error } } -func shouldRetry(err error) bool { +func DefaultShouldRetry(err error) bool { if err == nil { return false } @@ -241,7 +252,7 @@ func shouldRetry(err error) bool { } func Wait(ctx context.Context, timeout time.Duration, fn func() *Err) error { - return New[struct{}](WithTimeout(timeout), WithRetryFunc(shouldRetry)).Wait(ctx, func(_ context.Context) error { + return New[struct{}](WithTimeout(timeout), WithRetryFunc(DefaultShouldRetry)).Wait(ctx, func(_ context.Context) error { err := fn() if err != nil { return err @@ -251,7 +262,7 @@ func Wait(ctx context.Context, timeout time.Duration, fn func() *Err) error { } func Poll[T any](ctx context.Context, timeout time.Duration, fn func() (*T, *Err)) (*T, error) { - return New[T](WithTimeout(timeout), WithRetryFunc(shouldRetry)).Run(ctx, func(_ context.Context) (*T, error) { + return New[T](WithTimeout(timeout), WithRetryFunc(DefaultShouldRetry)).Run(ctx, func(_ context.Context) (*T, error) { res, err := fn() if err != nil { return res, err