From e677e26c3a5ae89768b1573a65081079512a2a99 Mon Sep 17 00:00:00 2001 From: Victor Schappert Date: Sun, 28 Mar 2021 10:52:55 -0700 Subject: [PATCH] Fix edge case where racing confuses retry.Times In the pre-racing v1.0.x world, it made sense to base the retry.Times DeciderFunc result on the execution's `Attempt` field, because `Attempt` as monotonically increasing, and by the time the retry policy's Decider was tested, `Attempt` was guaranteed to be equal to the number of retries. ...But... With v1.1.x which has the racing feature, there are some odd cases where using `Attempt` will create an undesirable result. In particular when you base things on `Attempt`, retry.Times would work differently depending on which of two racing attempts finished first even if in both cases the number of finished attempts was the same. Hence in this commit we introduce a new `AttemptEnds` field to ensure retry.Times decides based on number of total attempts completed, not which one completed first. Before this commit, Example #1 below would not cancel Attempt 1 after Attempt 0 completed, but Example #2 would cancel Attempt 0 after Attempt 1 completed, even though they are both logically equivalent from a retry standpoint. After this commit, both examples behave the same. Example 1: Retry policy uses retry.Times(1) Wave 0: +-----------------------------------------> Attempt 0: +----------------------> Attempt 1: +-----------------------------> Example 2: Retry policy uses retry.Times(1) Wave 0: +-----------------------------------------> Attempt 0: +----------------------------------> Attempt 1: +----------> --- client.go | 8 +++++--- client_test.go | 47 ++++++++++++++++++++++++++++++++++--------- request/execution.go | 8 ++++++-- retry/decider.go | 6 +++--- retry/decider_test.go | 20 ++++++++++++++---- retry/policy_test.go | 7 ++++--- 6 files changed, 72 insertions(+), 24 deletions(-) diff --git a/client.go b/client.go index 215c957..8c13ce6 100644 --- a/client.go +++ b/client.go @@ -286,6 +286,7 @@ func (es *execState) handleCheckpoint(attempt *attemptState) (drain bool, halt b es.handlers.run(AfterAttemptTimeout, es.exec) } es.handlers.run(AfterAttempt, es.exec) + es.exec.AttemptEnds++ es.exec.Racing-- attempt.resp = es.exec.Response attempt.body = es.exec.Body @@ -294,6 +295,7 @@ func (es *execState) handleCheckpoint(attempt *attemptState) (drain bool, halt b halt = attempt.redundant || es.planCancelled() || !es.retryPolicy.Decide(es.exec) return case panicked: + es.exec.AttemptEnds++ es.exec.Racing-- panic(attempt.panicVal) default: @@ -370,6 +372,7 @@ func (es *execState) cleanupWave() { as.cancel(false) switch as.checkpoint { case createdRequest, readBodyHandle: + es.exec.AttemptEnds++ es.exec.Racing-- case sentRequestHandle: as.ready <- struct{}{} @@ -386,9 +389,8 @@ func (es *execState) cleanupWave() { continue } fallthrough - case readBody: - es.exec.Racing-- - case panicked: + case readBody, panicked: + es.exec.AttemptEnds++ es.exec.Racing-- default: panic("httpx: bad attempt checkpoint") diff --git a/client_test.go b/client_test.go index c43fa13..d49e73f 100644 --- a/client_test.go +++ b/client_test.go @@ -127,22 +127,26 @@ func testClientHappyPath(t *testing.T) { cl.Handlers.mock(BeforeExecutionStart).On("Handle", BeforeExecutionStart, mock.MatchedBy(func(e *request.Execution) bool { return e.Start == time.Time{} && + e.AttemptEnds == e.Attempt && e.Plan != nil && e.Request == nil && e.Response == nil && !e.Ended() })).Once() cl.Handlers.mock(BeforeAttempt).On("Handle", BeforeAttempt, mock.MatchedBy(func(e *request.Execution) bool { return !e.Start.Before(before) && !e.Start.After(time.Now()) && - e.Request != nil && e.Response == nil && !e.Ended() + e.AttemptEnds == e.Attempt && e.Request != nil && e.Response == nil && !e.Ended() })).Once() cl.Handlers.mock(BeforeReadBody).On("Handle", BeforeReadBody, mock.MatchedBy(func(e *request.Execution) bool { - return e.Request != nil && e.Response == resp && e.Err == nil && !e.Ended() + return e.Request != nil && e.Response == resp && e.AttemptEnds == e.Attempt && + e.Err == nil && !e.Ended() })).Once() cl.Handlers.mock(AfterAttemptTimeout) // Add so we can assert it was never called. cl.Handlers.mock(AfterAttempt).On("Handle", AfterAttempt, mock.MatchedBy(func(e *request.Execution) bool { - return e.Request != nil && e.Response == resp && e.Err == nil && !e.Ended() + return e.Request != nil && e.Response == resp && e.AttemptEnds == e.Attempt && + e.Err == nil && !e.Ended() })).Once() cl.Handlers.mock(AfterPlanTimeout) // Add so we can assert it was never called. cl.Handlers.mock(AfterExecutionEnd).On("Handle", AfterExecutionEnd, mock.MatchedBy(func(e *request.Execution) bool { - return e.Request != nil && e.Response == resp && e.Err == nil && e.Attempt == 0 && + return e.Request != nil && e.Response == resp && e.Err == nil && + e.Attempt == 0 && e.AttemptEnds == 1 && e.Racing == 0 && e.Wave == 0 && e.Ended() })).Once() @@ -163,6 +167,7 @@ func testClientHappyPath(t *testing.T) { assert.Equal(t, 200, e.StatusCode()) assert.Equal(t, []byte("foo"), e.Body) assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) assert.Equal(t, 0, e.Racing) assert.Equal(t, 0, e.Wave) @@ -195,6 +200,7 @@ func testClientZeroValue(t *testing.T) { assert.Equal(t, 200, e.StatusCode()) assert.Empty(t, e.Body) assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) assert.Equal(t, 0, e.Racing) assert.Equal(t, 0, e.Wave) }, @@ -218,6 +224,7 @@ func testClientZeroValue(t *testing.T) { assert.Equal(t, 404, e.StatusCode()) assert.Equal(t, []byte("the thingy was not in the place"), e.Body) assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) assert.Equal(t, 0, e.Racing) assert.Equal(t, 0, e.Wave) }, @@ -228,7 +235,7 @@ func testClientZeroValue(t *testing.T) { StatusCode: 503, Body: []bodyChunk{ { - Data: []byte("ain't not service in these parts"), + Data: []byte("ain't no service in these parts"), }, }, }, @@ -239,8 +246,9 @@ func testClientZeroValue(t *testing.T) { assert.NotNil(t, e.Request) assert.NotNil(t, e.Response) assert.Equal(t, 503, e.StatusCode()) - assert.Equal(t, []byte("ain't not service in these parts"), e.Body) + assert.Equal(t, []byte("ain't no service in these parts"), e.Body) assert.Equal(t, retry.DefaultTimes, e.Attempt) + assert.Equal(t, retry.DefaultTimes+1, e.AttemptEnds) assert.Equal(t, 0, e.AttemptTimeouts) assert.Equal(t, 0, e.Racing) assert.Equal(t, retry.DefaultTimes, e.Wave) @@ -333,6 +341,7 @@ func testClientAttemptTimeout(t *testing.T) { assert.NotNil(t, e.Body) } assert.Equal(t, e.Attempt, 0) + assert.Equal(t, e.AttemptEnds, 1) assert.Equal(t, e.AttemptTimeouts, 1) assert.Equal(t, 0, e.Racing) assert.Equal(t, 0, e.Wave) @@ -432,6 +441,7 @@ func testClientBodyError(t *testing.T) { assert.Equal(t, 0, e.StatusCode()) } assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) assert.Equal(t, 1, e.AttemptTimeouts) assert.Equal(t, 0, e.Racing) assert.Equal(t, 0, e.Wave) @@ -467,6 +477,8 @@ func testClientBodyError(t *testing.T) { assert.NotNil(t, e.Response) assert.Equal(t, 202, e.StatusCode()) assert.Equal(t, []byte{}, e.Body) + assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) assert.Equal(t, []string{ "BeforeExecutionStart", "BeforeAttempt", @@ -502,7 +514,7 @@ func testClientRetryPlanTimeout(t *testing.T) { mockRetryPolicy.On("Wait", mock.Anything).Return(time.Hour).Maybe() cl.Handlers.mock(AfterPlanTimeout).On("Handle", AfterPlanTimeout, mock.MatchedBy(func(e *request.Execution) bool { err, ok := e.Err.(*url.Error) - return e.Attempt == 0 && e.AttemptTimeouts == 0 && + return e.Attempt == 0 && e.AttemptEnds == 1 && e.AttemptTimeouts == 0 && e.Request != nil && e.Response != nil && e.Body != nil && ok && err.Timeout() })).Return().Once() @@ -529,6 +541,7 @@ func testClientRetryPlanTimeout(t *testing.T) { assert.NotNil(t, e.Response) assert.NotNil(t, e.Body) assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) assert.Equal(t, 0, e.AttemptTimeouts) assert.Equal(t, 0, e.Racing) assert.Equal(t, 0, e.Wave) @@ -671,6 +684,7 @@ func testClientRetryVarious(t *testing.T) { } require.NotNil(t, e.Request) assert.Equal(t, i, e.Attempt) + assert.Equal(t, i+1, e.AttemptEnds) assert.Equal(t, 1, e.AttemptTimeouts) assert.Equal(t, 0, e.Racing) assert.Equal(t, i, e.Wave) @@ -720,6 +734,7 @@ func testClientEventHandlerPanicEnsureCancelCalled(t *testing.T) { require.Panics(t, func() { _, _ = cl.Get("test") }) require.NotNil(t, e) assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) assert.Equal(t, 0, e.Racing) assert.Equal(t, 0, e.Wave) require.NotNil(t, e.Request) @@ -896,6 +911,8 @@ func testClientPlanCancel(t *testing.T) { assert.Same(t, context.Canceled, urlError.Err) assert.Same(t, err, e.Err) assert.Same(t, p, e.Plan) + assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) }) t.Run("plan cancelled after request", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) @@ -930,6 +947,8 @@ func testClientPlanCancel(t *testing.T) { assert.Same(t, context.Canceled, urlError.Err) assert.Same(t, err, e.Err) assert.Same(t, p, e.Plan) + assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) }) } @@ -974,6 +993,8 @@ func testClientPlanChange(t *testing.T) { var urlError *url.Error require.ErrorAs(t, err, &urlError) assert.Same(t, nonRetryableErr, urlError.Unwrap()) + assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) }) t.Run("to nil (panic)", func(t *testing.T) { doer := newMockHTTPDoer(t) @@ -1070,6 +1091,7 @@ func testClientRacingNeverStart(t *testing.T) { assert.NoError(t, e.Err) assert.Equal(t, 204, e.StatusCode()) assert.Equal(t, 0, e.Attempt) + assert.Equal(t, 1, e.AttemptEnds) assert.Equal(t, 0, e.Racing) assert.Equal(t, 0, e.Wave) assert.Equal(t, []string{ @@ -1124,10 +1146,10 @@ func testClientRacingRetry(t *testing.T) { Return(&http.Response{StatusCode: 200, Body: ioutil.NopCloser(strings.NewReader("racing-retry=3"))}, nil). Once() retryPolicy.On("Decide", mock.MatchedBy(func(e *request.Execution) bool { - return e.Wave == 0 && e.Attempt <= 2 + return e.Wave == 0 && e.Attempt <= 2 && e.AttemptEnds <= 3 })).Return(true).Times(3) retryPolicy.On("Decide", mock.MatchedBy(func(e *request.Execution) bool { - return e.Wave == 1 && e.Attempt == 3 + return e.Wave == 1 && e.Attempt == 3 && e.AttemptEnds == 4 })).Return(false).Once() retryPolicy.On("Wait", mock.Anything).Return(time.Nanosecond).Once() racingPolicy.On("Schedule", mock.MatchedBy(func(e *request.Execution) bool { @@ -1150,6 +1172,7 @@ func testClientRacingRetry(t *testing.T) { require.NoError(t, e.Err) assert.Equal(t, 1, e.Wave) assert.Equal(t, 3, e.Attempt) + assert.Equal(t, 4, e.AttemptEnds) n := 2 + 4*3 require.Len(t, trace.calls, n) assert.Equal(t, []string{ @@ -1331,7 +1354,10 @@ func testRacingPlanCancelDuringWaveLoop(t *testing.T) { assert.Same(t, err, e.Err) assert.Equal(t, 0, e.Wave) assert.GreaterOrEqual(t, e.Attempt, 0) + assert.GreaterOrEqual(t, e.AttemptEnds, 1) assert.LessOrEqual(t, e.Attempt, N+5) + assert.LessOrEqual(t, e.AttemptEnds, N+6) + assert.Less(t, e.Attempt, e.AttemptEnds) } func testClientRacingPanic(t *testing.T) { @@ -1566,8 +1592,11 @@ func testClientRacingMultipleWaves(t *testing.T) { assert.GreaterOrEqual(t, call, (numWaves-1)*minRacing) assert.Equal(t, e.Wave, numWaves-1) assert.GreaterOrEqual(t, e.Attempt, (numWaves-1)*minRacing) + assert.Greater(t, e.AttemptEnds, (numWaves-1)*minRacing) assert.GreaterOrEqual(t, e.Attempt, curWaveFirstAttempt) + assert.Greater(t, e.AttemptEnds, curWaveFirstAttempt) assert.Less(t, e.Attempt, numWaves*maxRacing) + assert.LessOrEqual(t, e.AttemptEnds, numWaves*maxRacing) } type mockHTTPDoer struct { diff --git a/request/execution.go b/request/execution.go index d1044ba..568f251 100644 --- a/request/execution.go +++ b/request/execution.go @@ -54,8 +54,8 @@ type Execution struct { // retries will have an attempt number of 2. Attempt int - // AttemptTimeouts is the count of the number of times an HTTP - // request attempt timed out during the execution. + // AttemptTimeouts is the count of how many HTTP request attempts + // timed out during the execution. // // Plan timeouts (when the plan's own context deadline is exceeded) // do not contribute to the attempt timeout counter, but if an @@ -120,6 +120,10 @@ type Execution struct { // during events relating to request attempts and zero otherwise. Racing int + // AttemptEnds is the count of how many HTTP request attempts have + // ended within the execution. + AttemptEnds int + // Data contains arbitrary user data. The httpx library will not // touch this field, and it will typically be nil unless used by // event handler writers. diff --git a/retry/decider.go b/retry/decider.go index fef62dd..070ed15 100644 --- a/retry/decider.go +++ b/retry/decider.go @@ -88,11 +88,11 @@ func (f DeciderFunc) Or(g DeciderFunc) DeciderFunc { } // Times constructs a retry decider which allows up to n retries. The -// returned decider returns true while the execution attempt index -// e.Attempt is less than n, and false otherwise. +// returned decider returns true while the number of finished attempts +// within the execution is less than or equal to n, and false otherwise. func Times(n int) DeciderFunc { return func(e *request.Execution) bool { - return e.Attempt < n + return e.AttemptEnds <= n } } diff --git a/retry/decider_test.go b/retry/decider_test.go index 539c258..bf6b8a8 100644 --- a/retry/decider_test.go +++ b/retry/decider_test.go @@ -29,9 +29,11 @@ func TestDefaultDecider(t *testing.T) { t.Run(fmt.Sprintf("codes[%d]=%d", i, code), func(t *testing.T) { for j := 0; j < DefaultTimes; j++ { e.Attempt = j + e.AttemptEnds = e.Attempt + 1 assert.True(t, DefaultDecider(&e), fmt.Sprintf("Expect true for attempt %d", j)) } e.Attempt = DefaultTimes + e.AttemptEnds = e.Attempt + 1 assert.False(t, DefaultDecider(&e), fmt.Sprintf("Expect false for attempt %d", e.Attempt)) }) } @@ -45,8 +47,10 @@ func TestDefaultDecider(t *testing.T) { } t.Run(fmt.Sprintf("codes[%d]=%d", i, code), func(t *testing.T) { e.Attempt = 0 + e.AttemptEnds = e.Attempt + 1 assert.False(t, DefaultDecider(&e), "Expect false for attempt 0") e.Attempt = 4 + e.AttemptEnds = e.Attempt + 1 assert.False(t, DefaultDecider(&e), "Expect false for attempt 4") }) } @@ -60,9 +64,11 @@ func TestDefaultDecider(t *testing.T) { t.Run(fmt.Sprintf("transientErrs[%d]=%v", i, te), func(t *testing.T) { for j := 0; j < DefaultTimes; j++ { e.Attempt = j + e.AttemptEnds = e.Attempt + 1 assert.True(t, DefaultDecider(&e), fmt.Sprintf("Expect true for attempt %d", j)) } e.Attempt = DefaultTimes + e.AttemptEnds = e.Attempt + 1 assert.False(t, DefaultDecider(&e), fmt.Sprintf("Expect false for attempt %d", e.Attempt)) }) } @@ -75,8 +81,10 @@ func TestDefaultDecider(t *testing.T) { } t.Run(fmt.Sprintf("nonTransientErrs[%d]=%v", i, nte), func(t *testing.T) { e.Attempt = 0 + e.AttemptEnds = e.Attempt + 1 assert.False(t, DefaultDecider(&e), "Expect false for attempt 0") e.Attempt = 4 + e.AttemptEnds = e.Attempt + 1 assert.False(t, DefaultDecider(&e), "Expect false for attempt 4") }) } @@ -131,13 +139,17 @@ func TestDeciderOr(t *testing.T) { func TestTimes(t *testing.T) { zero := Times(0) - assert.False(t, zero(&request.Execution{})) + assert.True(t, zero(&request.Execution{})) + assert.False(t, zero(&request.Execution{AttemptEnds: 1})) one := Times(1) assert.True(t, one(&request.Execution{})) - assert.False(t, one(&request.Execution{Attempt: 1})) + assert.True(t, one(&request.Execution{AttemptEnds: 1})) + assert.False(t, one(&request.Execution{AttemptEnds: 2})) two := Times(2) - assert.True(t, two(&request.Execution{Attempt: 1})) - assert.False(t, two(&request.Execution{Attempt: 2})) + assert.True(t, two(&request.Execution{})) + assert.True(t, two(&request.Execution{AttemptEnds: 1})) + assert.True(t, two(&request.Execution{AttemptEnds: 2})) + assert.False(t, two(&request.Execution{AttemptEnds: 3})) } func TestBefore(t *testing.T) { diff --git a/retry/policy_test.go b/retry/policy_test.go index 8fae5e9..7bcd21f 100644 --- a/retry/policy_test.go +++ b/retry/policy_test.go @@ -31,8 +31,8 @@ func TestDefault(t *testing.T) { })) } assert.False(t, DefaultPolicy.Decide(&request.Execution{ - Attempt: DefaultTimes, - Err: syscall.ETIMEDOUT, + AttemptEnds: DefaultTimes + 1, + Err: syscall.ETIMEDOUT, })) }) t.Run("Waiter", func(t *testing.T) { @@ -50,7 +50,8 @@ func TestDefault(t *testing.T) { } func TestNever(t *testing.T) { - assert.False(t, Never.Decide(&request.Execution{})) + assert.True(t, Never.Decide(&request.Execution{})) + assert.False(t, Never.Decide(&request.Execution{AttemptEnds: 1})) } func TestNewPolicy(t *testing.T) {