diff --git a/commands/account.go b/commands/account.go index 4b6fba410..570e3f417 100644 --- a/commands/account.go +++ b/commands/account.go @@ -14,7 +14,10 @@ limitations under the License. package commands import ( + "fmt" + "github.com/digitalocean/doctl/commands/displayers" + "github.com/digitalocean/doctl/do" "github.com/spf13/cobra" ) @@ -64,6 +67,20 @@ func RunAccountGet(c *CmdConfig) error { // RunAccountRateLimit retrieves API rate limits for the account. func RunAccountRateLimit(c *CmdConfig) error { + // We disable reties by replacing the HTTPClient as we only want the + // rate-limit headers regardless of response status. Without doing so, + // we would retry until retries were exhausted if rate-limited delaying a + // response for no purpose. + if RetryMax > 0 { + accessToken := c.getContextAccessToken() + godoClient, err := c.Doit.GetGodoClient(Trace, false, accessToken) + if err != nil { + return fmt.Errorf("Unable to initialize DigitalOcean API client: %s", err) + } + + c.Account = func() do.AccountService { return do.NewAccountService(godoClient) } + } + rl, err := c.Account().RateLimit() if err != nil { return err diff --git a/commands/account_test.go b/commands/account_test.go index 8ff014b83..9fe997c62 100644 --- a/commands/account_test.go +++ b/commands/account_test.go @@ -59,6 +59,7 @@ func TestAccountGet(t *testing.T) { func TestAccountGetRateLimit(t *testing.T) { withTestClient(t, func(config *CmdConfig, tm *tcMocks) { + RetryMax = 0 now := time.Now() testRateLimit.Reset = godo.Timestamp{Time: now} tm.account.EXPECT().RateLimit().Return(testRateLimit, nil) diff --git a/commands/command_config.go b/commands/command_config.go index a67fd732c..b1ff12ffc 100644 --- a/commands/command_config.go +++ b/commands/command_config.go @@ -85,7 +85,7 @@ func NewCmdConfig(ns string, dc doctl.Config, out io.Writer, args []string, init initServices: func(c *CmdConfig) error { accessToken := c.getContextAccessToken() - godoClient, err := c.Doit.GetGodoClient(Trace, accessToken) + godoClient, err := c.Doit.GetGodoClient(Trace, true, accessToken) if err != nil { return fmt.Errorf("Unable to initialize DigitalOcean API client: %s", err) } diff --git a/commands/doit.go b/commands/doit.go index 8f7355ed9..fdda3360a 100644 --- a/commands/doit.go +++ b/commands/doit.go @@ -61,6 +61,11 @@ var ( //Interactive toggle interactive behavior Interactive bool + // Retry settings to pass through to godo.RetryConfig + RetryMax int + RetryWaitMax int + RetryWaitMin int + requiredColor = color.New(color.Bold).SprintfFunc() ) @@ -99,6 +104,17 @@ func init() { } rootPFlagSet.BoolVarP(&Interactive, doctl.ArgInteractive, "", interactive, interactiveHelpText) + rootPFlagSet.IntVar(&RetryMax, "http-retry-max", 5, "Set maximum number of retries for requests that fail with a 429 or 500-level error") + viper.BindPFlag("http-retry-max", rootPFlagSet.Lookup("http-retry-max")) + + rootPFlagSet.IntVar(&RetryWaitMax, "http-retry-wait-max", 30, "Set the minimum number of seconds to wait before retrying a failed request") + viper.BindPFlag("http-retry-wait-max", rootPFlagSet.Lookup("http-retry-wait-max")) + DoitCmd.PersistentFlags().MarkHidden("http-retry-wait-max") + + rootPFlagSet.IntVar(&RetryWaitMin, "http-retry-wait-min", 1, "Set the maximum number of seconds to wait before retrying a failed request") + viper.BindPFlag("http-retry-wait-min", rootPFlagSet.Lookup("http-retry-wait-min")) + DoitCmd.PersistentFlags().MarkHidden("http-retry-wait-min") + addCommands() cobra.OnInitialize(initConfig) diff --git a/doit.go b/doit.go index 0bb712725..e31873c60 100644 --- a/doit.go +++ b/doit.go @@ -207,7 +207,7 @@ func (glv *GithubLatestVersioner) LatestVersion() (string, error) { // Config is an interface that represent doit's config. type Config interface { - GetGodoClient(trace bool, accessToken string) (*godo.Client, error) + GetGodoClient(trace, allowRetries bool, accessToken string) (*godo.Client, error) GetDockerEngineClient() (builder.DockerEngineClient, error) SSH(user, host, keyPath string, port int, opts ssh.Options) runner.Runner Listen(url *url.URL, token string, schemaFunc listen.SchemaFunc, out io.Writer) listen.ListenerService @@ -231,7 +231,7 @@ type LiveConfig struct { var _ Config = &LiveConfig{} // GetGodoClient returns a GodoClient. -func (c *LiveConfig) GetGodoClient(trace bool, accessToken string) (*godo.Client, error) { +func (c *LiveConfig) GetGodoClient(trace, allowRetries bool, accessToken string) (*godo.Client, error) { if accessToken == "" { return nil, fmt.Errorf("access token is required. (hint: run 'doctl auth init')") } @@ -239,31 +239,63 @@ func (c *LiveConfig) GetGodoClient(trace bool, accessToken string) (*godo.Client tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: accessToken}) oauthClient := oauth2.NewClient(context.Background(), tokenSource) + args := []godo.ClientOpt{ + godo.SetUserAgent(userAgent()), + } + + logger := log.New(os.Stderr, "doctl: ", log.LstdFlags) + + retryMax := viper.GetInt("http-retry-max") + retryWaitMax := viper.GetInt("http-retry-wait-max") + retryWaitMin := viper.GetInt("http-retry-wait-min") + if retryMax > 0 && allowRetries { + retryConfig := godo.RetryConfig{ + RetryMax: retryMax, + } + + if retryWaitMax > 0 { + retryConfig.RetryWaitMax = godo.PtrTo(float64(retryWaitMax)) + } + + if retryWaitMin > 0 { + retryConfig.RetryWaitMin = godo.PtrTo(float64(retryWaitMin)) + } + + if trace { + retryConfig.Logger = logger + } + + args = append(args, godo.WithRetryAndBackoffs(retryConfig)) + } + + apiURL := viper.GetString("api-url") + if apiURL != "" { + args = append(args, godo.SetBaseURL(apiURL)) + } + + client, err := godo.New(oauthClient, args...) + if err != nil { + return nil, err + } + if trace { - r := newRecorder(oauthClient.Transport) + r := newRecorder(client.HTTPClient.Transport) go func() { for { select { case msg := <-r.req: - log.Println("->", strconv.Quote(msg)) + logger.Println("->", strconv.Quote(msg)) case msg := <-r.resp: - log.Println("<-", strconv.Quote(msg)) + logger.Println("<-", strconv.Quote(msg)) } } }() - oauthClient.Transport = r - } - - args := []godo.ClientOpt{godo.SetUserAgent(userAgent())} - - apiURL := viper.GetString("api-url") - if apiURL != "" { - args = append(args, godo.SetBaseURL(apiURL)) + client.HTTPClient.Transport = r } - return godo.New(oauthClient, args...) + return client, nil } // GetDockerEngineClient returns a container engine client. @@ -464,7 +496,7 @@ func NewTestConfig() *TestConfig { // GetGodoClient mocks a GetGodoClient call. The returned godo client will // be nil. -func (c *TestConfig) GetGodoClient(trace bool, accessToken string) (*godo.Client, error) { +func (c *TestConfig) GetGodoClient(trace, allowRetries bool, accessToken string) (*godo.Client, error) { return &godo.Client{}, nil } diff --git a/integration/account_test.go b/integration/account_test.go index 53997bc2d..6edd77b5a 100644 --- a/integration/account_test.go +++ b/integration/account_test.go @@ -167,7 +167,7 @@ var _ = suite("account/ratelimit", func(t *testing.T, when spec.G, it spec.S) { expect.Equal(expectedOutput, strings.TrimSpace(string(output))) }) - it("doesn't return an error when rate-limted", func() { + it("doesn't return an error when rate-limited", func() { cmd := exec.Command(builtBinaryPath, "-t", "token-with-ratelimit-exhausted", "-u", server.URL, @@ -176,12 +176,26 @@ var _ = suite("account/ratelimit", func(t *testing.T, when spec.G, it spec.S) { ) output, err := cmd.CombinedOutput() - expect.NoError(err) + expect.NoError(err, string(output)) t := time.Unix(1565385881, 0) expectedOutput := strings.TrimSpace(fmt.Sprintf(ratelimitExhaustedOutput, t)) expect.Equal(expectedOutput, strings.TrimSpace(string(output))) }) + + it("doesn't retry when rate-limited", func() { + cmd := exec.Command(builtBinaryPath, + "-t", "token-with-ratelimit-exhausted", + "-u", server.URL, + "account", + "ratelimit", "--trace", + ) + + output, err := cmd.CombinedOutput() + expect.NoError(err, string(output)) + + expect.NotContains(strings.TrimSpace(string(output)), "retrying in") + }) }) const ( diff --git a/integration/projects_delete_test.go b/integration/projects_delete_test.go index 63ec9697e..a45d3a971 100644 --- a/integration/projects_delete_test.go +++ b/integration/projects_delete_test.go @@ -55,6 +55,7 @@ var _ = suite("projects/delete", func(t *testing.T, when spec.G, it spec.S) { "test-project-1", "test-project-2", "-f", + "--http-retry-max", "0", ) output, err := cmd.CombinedOutput() diff --git a/integration/retry_flag_test.go b/integration/retry_flag_test.go new file mode 100644 index 000000000..bbfc0a81c --- /dev/null +++ b/integration/retry_flag_test.go @@ -0,0 +1,145 @@ +package integration + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/http/httputil" + "os/exec" + "strings" + "testing" + + "github.com/sclevine/spec" + "github.com/stretchr/testify/require" +) + +var _ = suite("retries/server-error", func(t *testing.T, when spec.G, it spec.S) { + var ( + expect *require.Assertions + server *httptest.Server + ) + + it.Before(func() { + var ( + requestCount int + errResp = `{"id": "server_error", "message": "something broke"}` + ) + expect = require.New(t) + + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.Header().Add("content-type", "application/json") + + switch req.URL.Path { + case "/v2/account": + requestCount++ + + auth := req.Header.Get("Authorization") + if auth != "Bearer some-magic-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + if req.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + if requestCount < 5 { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(errResp)) + return + } + + w.Write([]byte(accountGetResponse)) + default: + dump, err := httputil.DumpRequest(req, true) + if err != nil { + t.Fatal("failed to dump request") + } + + t.Fatalf("received unknown request: %s", dump) + } + })) + }) + + it("retries five time by default and succeeds", func() { + cmd := exec.Command(builtBinaryPath, + "-t", "some-magic-token", + "-u", server.URL, + "account", + "get", + ) + + output, err := cmd.CombinedOutput() + expect.NoError(err) + + expect.Equal(strings.TrimSpace(accountOutput), strings.TrimSpace(string(output))) + }) + + it("retries are logged with trace flag", func() { + cmd := exec.Command(builtBinaryPath, + "-t", "some-magic-token", + "-u", server.URL, + "account", + "get", + "--trace", + ) + + output, err := cmd.CombinedOutput() + expect.NoError(err) + + expect.Contains(strings.TrimSpace(string(output)), "retrying in") + }) + + when("respects the http-retry-max flag and gives up", func() { + it("only displays the correct fields", func() { + cmd := exec.Command(builtBinaryPath, + "-t", "some-magic-token", + "-u", server.URL, + "account", + "get", + "--http-retry-max", "2", + ) + + output, err := cmd.CombinedOutput() + expect.Error(err) + expectedErr := fmt.Sprintf("Error: GET %s/v2/account: 500 something broke; giving up after 3 attempt(s)", server.URL) + expect.Equal(strings.TrimSpace(string(output)), expectedErr) + }) + }) + + when("retries are disabled when http-retry-max is set to 0", func() { + it("only displays the correct fields", func() { + cmd := exec.Command(builtBinaryPath, + "-t", "some-magic-token", + "-u", server.URL, + "account", + "get", + "--http-retry-max", "0", + ) + + output, err := cmd.CombinedOutput() + expect.Error(err) + + // Does not contain "giving up after" + expectedErr := fmt.Sprintf("Error: GET %s/v2/account: 500 something broke", server.URL) + expect.Equal(strings.TrimSpace(string(output)), expectedErr) + }) + }) + + when("doesn't retry 400-level errors", func() { + it("only displays the correct fields", func() { + cmd := exec.Command(builtBinaryPath, + "-t", "bad-token", + "-u", server.URL, + "account", + "get", + ) + + output, err := cmd.CombinedOutput() + expect.Error(err) + + expect.NotContains(strings.TrimSpace(string(output)), "giving up after") + expect.Contains(strings.TrimSpace(string(output)), "401") + }) + }) +})