From 225e35863a70dcd525d8193622ed259f42c5122f Mon Sep 17 00:00:00 2001
From: Shaun Ford
Date: Mon, 21 May 2018 10:20:44 -0700
Subject: [PATCH] Retry on 500 errors when authenticating or reading secrets (#37)

* Retry on 500 errors when authenticating or reading secrets

* Update tests and gradle version to 5.3.0

* Update javadoc
---
 gradle.properties                             |  2 +-
 .../nike/cerberus/client/CerberusClient.java  | 66 +++++++++++++++++++-
 .../auth/aws/BaseAwsCredentialsProvider.java  | 60 ++++++++++++++++-
 .../cerberus/client/CerberusClientTest.java   | 63 ++++++++++++++++++-
 .../aws/BaseAwsCredentialsProviderTest.java   | 59 +++++++++++++++++
 5 files changed, 244 insertions(+), 6 deletions(-)

diff --git a/gradle.properties b/gradle.properties
index 8133d23..5bd19d0 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -13,6 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-version=5.2.0
+version=5.3.0
 groupId=com.nike
 artifactId=cerberus-client
diff --git a/src/main/java/com/nike/cerberus/client/CerberusClient.java b/src/main/java/com/nike/cerberus/client/CerberusClient.java
index fd1aae6..850ff32 100644
--- a/src/main/java/com/nike/cerberus/client/CerberusClient.java
+++ b/src/main/java/com/nike/cerberus/client/CerberusClient.java
@@ -51,6 +51,7 @@
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
 
 /**
  * Client for interacting with a Cerberus.
@@ -63,6 +64,10 @@ public class CerberusClient {
 
     public static final MediaType DEFAULT_MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8");
 
+    protected static final int DEFAULT_NUM_RETRIES = 3;
+
+    protected static final int DEFAULT_RETRY_INTERVAL_IN_MILLIS = 200;
+
     private final CerberusCredentialsProvider credentialsProvider;
 
     private final OkHttpClient httpClient;
@@ -228,7 +233,7 @@ public CerberusResponse read(final String path) {
         final HttpUrl url = buildUrl(SECRET_PATH_PREFIX, path);
         logger.debug("read: requestUrl={}", url);
 
-        final Response response = execute(url, HttpMethod.GET, null);
+        final Response response = executeWithRetry(url, HttpMethod.GET, null, DEFAULT_NUM_RETRIES, DEFAULT_RETRY_INTERVAL_IN_MILLIS);
 
         if (response.code() != HttpStatus.OK) {
             parseAndThrowErrorResponse(response);
@@ -433,6 +438,55 @@ protected HttpUrl buildUrl(final String prefix, final String path) {
         return HttpUrl.parse(baseUrl + prefix + path);
     }
 
+    /**
+     * Executes an HTTP request and retries if a 500 level error is returned or the call fails.
+     * A retried 5xx response is closed before the next attempt so its connection is released,
+     * and no time is spent sleeping once the final attempt has completed.
+     *
+     * @param url Full URL to which to make the HTTP request
+     * @param method HTTP Method (e.g. GET, PUT, POST)
+     * @param requestBody Body to add to the request. Nullable
+     * @param numRetries Maximum number of attempts, including the first call
+     * @param sleepIntervalInMillis Base time in milliseconds to sleep before a retry; doubled after every attempt. Zero for no sleep
+     * @return Any HTTP response with status code below 500, or the last error response if the final attempt returned a 500
+     * @throws CerberusClientException If the final attempt failed without producing a response
+     */
+    protected Response executeWithRetry(final HttpUrl url,
+                                        final String method,
+                                        final Object requestBody,
+                                        final int numRetries,
+                                        final int sleepIntervalInMillis) {
+        CerberusClientException exception = null;
+        Response response = null;
+        for (int retryNumber = 0; retryNumber < numRetries; retryNumber++) {
+            if (retryNumber > 0) {
+                // back off only between attempts: interval, 2 * interval, 4 * interval, ...
+                sleep(sleepIntervalInMillis * (long) Math.pow(2, retryNumber - 1));
+            }
+            try {
+                response = execute(url, method, requestBody);
+                // only the outcome of the most recent attempt is reported to the caller
+                exception = null;
+                if (response.code() < 500) {
+                    return response;
+                }
+                if (retryNumber < numRetries - 1) {
+                    // this response is about to be discarded; release its connection
+                    response.body().close();
+                }
+            } catch (CerberusClientException cce) {
+                logger.debug(String.format("Failed to call %s %s. Retrying...", method, url), cce);
+                exception = cce;
+            }
+        }
+
+        if (exception != null) {
+            throw exception;
+        } else {
+            return response;
+        }
+    }
+
     /**
      * Executes the HTTP request based on the input parameters.
      *
@@ -632,4 +686,14 @@ protected byte[] responseBodyAsBytes(Response response) {
             throw new CerberusClientException("ERROR failed to print ");
         }
     }
+
+    private void sleep(long milliseconds) {
+        try {
+            TimeUnit.MILLISECONDS.sleep(milliseconds);
+        } catch (InterruptedException ie) {
+            logger.warn("Sleep interval interrupted.", ie);
+            // restore the interrupt flag so callers can still observe the interruption
+            Thread.currentThread().interrupt();
+        }
+    }
 }
diff --git a/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java b/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java
index 2825f50..4d97ee4 100644
--- a/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java
+++ b/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java
@@ -58,6 +58,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 
@@ -80,6 +81,10 @@ public abstract class BaseAwsCredentialsProvider implements CerberusCredentialsP
 
     private static final Logger LOGGER = LoggerFactory.getLogger(BaseAwsCredentialsProvider.class);
 
+    protected static final int DEFAULT_AUTH_RETRIES = 3;
+
+    protected static final int DEFAULT_RETRY_INTERVAL_IN_MILLIS = 200;
+
     private final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock();
 
     private final Lock readLock = readWriteLock.readLock();
@@ -250,7 +255,7 @@ protected String getEncryptedAuthData(final String iamPrincipalArn, Region regio
                     .addHeader(ClientVersion.CERBERUS_CLIENT_HEADER, cerberusJavaClientHeaderValue)
                     .method(HttpMethod.POST, buildCredentialsRequestBody(iamPrincipalArn, region));
 
-            Response response = httpClient.newCall(requestBuilder.build()).execute();
+            Response response = executeRequestWithRetry(requestBuilder.build(), DEFAULT_AUTH_RETRIES, DEFAULT_RETRY_INTERVAL_IN_MILLIS);
 
             if (response.code() != HttpStatus.OK) {
                 parseAndThrowErrorResponse(response.code(), response.body().string());
@@ -298,6 +303,59 @@ protected CerberusAuthResponse decryptToken(AWSKMS kmsClient, String encryptedTo
         return gson.fromJson(decryptedAuthData, CerberusAuthResponse.class);
     }
 
+    /**
+     * Executes an HTTP request and retries if a 500 level error is returned or the call fails.
+     * A retried 5xx response is closed before the next attempt so its connection is released,
+     * and no time is spent sleeping once the final attempt has completed.
+     *
+     * @param request The request to execute
+     * @param numRetries The maximum number of attempts, including the first call
+     * @param sleepIntervalInMillis Base time in milliseconds to sleep before a retry; doubled after every attempt. Zero for no sleep.
+     * @return Any HTTP response with status code below 500, or the last error response if the final attempt returned a 500
+     * @throws IOException If an IOException occurs during the last attempt, then rethrow the error
+     */
+    protected Response executeRequestWithRetry(Request request, int numRetries, int sleepIntervalInMillis) throws IOException {
+        IOException exception = null;
+        Response response = null;
+        for (int retryNumber = 0; retryNumber < numRetries; retryNumber++) {
+            if (retryNumber > 0) {
+                // back off only between attempts: interval, 2 * interval, 4 * interval, ...
+                sleep(sleepIntervalInMillis * (long) Math.pow(2, retryNumber - 1));
+            }
+            try {
+                response = httpClient.newCall(request).execute();
+                // only the outcome of the most recent attempt is reported to the caller
+                exception = null;
+                if (response.code() < 500) {
+                    return response;
+                }
+                if (retryNumber < numRetries - 1) {
+                    // this response is about to be discarded; release its connection
+                    response.body().close();
+                }
+            } catch (IOException ioe) {
+                LOGGER.debug(String.format("Failed to call %s %s. Retrying...", request.method(), request.url()), ioe);
+                exception = ioe;
+            }
+        }
+
+        if (exception != null) {
+            throw exception;
+        } else {
+            return response;
+        }
+    }
+
+    private void sleep(long milliseconds) {
+        try {
+            TimeUnit.MILLISECONDS.sleep(milliseconds);
+        } catch (InterruptedException ie) {
+            LOGGER.warn("Sleep interval interrupted.", ie);
+            // restore the interrupt flag so callers can still observe the interruption
+            Thread.currentThread().interrupt();
+        }
+    }
+
     private RequestBody buildCredentialsRequestBody(final String iamPrincipalArn, Region region) {
         final String regionName = region == null ? Regions.getCurrentRegion().getName() : region.getName();
 
diff --git a/src/test/java/com/nike/cerberus/client/CerberusClientTest.java b/src/test/java/com/nike/cerberus/client/CerberusClientTest.java
index fda4075..303f9fb 100644
--- a/src/test/java/com/nike/cerberus/client/CerberusClientTest.java
+++ b/src/test/java/com/nike/cerberus/client/CerberusClientTest.java
@@ -19,10 +19,9 @@
 import com.nike.cerberus.client.auth.DefaultCerberusCredentialsProviderChain;
 import com.nike.cerberus.client.auth.CerberusCredentials;
 import com.nike.cerberus.client.auth.CerberusCredentialsProvider;
-import com.nike.cerberus.client.http.HttpStatus;
-import com.nike.cerberus.client.model.CerberusClientTokenResponse;
 import com.nike.cerberus.client.model.CerberusListResponse;
 import com.nike.cerberus.client.model.CerberusResponse;
+import okhttp3.Call;
 import okhttp3.Headers;
 import okhttp3.HttpUrl;
 import okhttp3.OkHttpClient;
@@ -42,9 +41,13 @@
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 
+import static com.nike.cerberus.client.CerberusClient.DEFAULT_NUM_RETRIES;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 /**
@@ -135,6 +138,60 @@ public void read_returns_map_of_data_for_specified_path_if_exists() throws IOExc
         assertThat(cerberusResponse.getData().get("value")).isEqualToIgnoringCase("world");
     }
 
+    @Test
+    public void read_does_not_retry_on_200() throws IOException {
+        mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson("secret")));
+        mockWebServer.enqueue(new MockResponse().setResponseCode(500).setBody(getResponseJson("error")));
+
+        CerberusResponse cerberusResponse = cerberusClient.read("app/api-key");
+
+        assertThat(cerberusResponse).isNotNull();
+        assertThat(cerberusResponse.getData().containsKey("value")).isTrue();
+        assertThat(cerberusResponse.getData().get("value")).isEqualToIgnoringCase("world");
+    }
+
+    @Test
+    public void read_retries_on_500_errors() throws IOException {
+        for (int i = 0; i < DEFAULT_NUM_RETRIES - 1; i++) {
+            mockWebServer.enqueue(new MockResponse().setResponseCode(500).setBody(getResponseJson("error")));
+        }
+        mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson("secret")));
+
+        CerberusResponse cerberusResponse = cerberusClient.read("app/api-key");
+
+        assertThat(cerberusResponse).isNotNull();
+        assertThat(cerberusResponse.getData().containsKey("value")).isTrue();
+        assertThat(cerberusResponse.getData().get("value")).isEqualToIgnoringCase("world");
+    }
+
+    @Test
+    public void read_retries_on_IOException() throws IOException {
+        UrlResolver urlResolver = mock(UrlResolver.class);
+        when(urlResolver.resolve()).thenReturn("http://localhost:" + mockWebServer.getPort());
+
+        OkHttpClient httpClient = mock(OkHttpClient.class);
+        Call call = mock(Call.class);
+        when(call.execute()).thenThrow(new IOException());
+        when(httpClient.newCall(any(Request.class))).thenReturn(call);
+        final CerberusCredentialsProvider cerberusCredentialsProvider = mock(CerberusCredentialsProvider.class);
+        when(cerberusCredentialsProvider.getCredentials()).thenReturn(new TestCerberusCredentials());
+
+        CerberusClient cerberusClient = new CerberusClient(urlResolver, cerberusCredentialsProvider, httpClient);
+        try {
+            cerberusClient.read("app/api-key");
+
+            // code should not reach this point, throw an error if it does
+            throw new AssertionError("Expected CerberusClientException, but was not thrown");
+        } catch(CerberusClientException cce) {  // catch this error so that the remaining tests will run
+            // ensure that error is thrown because of mocked IOException
+            if ( !(cce.getCause() instanceof IOException) ) {
+                throw new AssertionError("Expected error cause to be IOException, but was " + cce.getCause().getClass());
+            }
+        }
+
+        verify(httpClient, times(DEFAULT_NUM_RETRIES)).newCall(any(Request.class));
+    }
+
     @Test
     public void read_throws_cerberus_server_exception_if_response_is_not_ok() {
         final MockResponse response = new MockResponse();
@@ -156,7 +213,7 @@ public void read_throws_runtime_exception_if_unexpected_error_encountered() thro
         final String cerberusUrl = "http://localhost:" + serverSocket.getLocalPort();
         final CerberusCredentialsProvider cerberusCredentialsProvider = mock(CerberusCredentialsProvider.class);
         final OkHttpClient httpClient = buildHttpClient(1, TimeUnit.SECONDS);
-        cerberusClient = new CerberusClient(new StaticCerberusUrlResolver(cerberusUrl), cerberusCredentialsProvider, httpClient);
+        CerberusClient cerberusClient = new CerberusClient(new StaticCerberusUrlResolver(cerberusUrl), cerberusCredentialsProvider, httpClient);
 
         when(cerberusCredentialsProvider.getCredentials()).thenReturn(new TestCerberusCredentials());
 
diff --git a/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java b/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java
index 0d1b2fd..e36780f 100644
--- a/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java
+++ b/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java
@@ -23,6 +23,9 @@
 import com.nike.cerberus.client.CerberusServerException;
 import com.nike.cerberus.client.DefaultCerberusUrlResolver;
 import com.nike.cerberus.client.UrlResolver;
+import okhttp3.Call;
+import okhttp3.OkHttpClient;
+import okhttp3.Request;
 import okhttp3.mockwebserver.MockResponse;
 import okhttp3.mockwebserver.MockWebServer;
 import org.junit.After;
@@ -31,7 +34,11 @@
 
 import java.io.IOException;
 
+import static com.nike.cerberus.client.auth.aws.BaseAwsCredentialsProvider.DEFAULT_AUTH_RETRIES;
+import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 import static org.powermock.api.mockito.PowerMockito.mock;
 import static org.powermock.api.mockito.PowerMockito.when;
 
@@ -39,6 +46,7 @@ public class BaseAwsCredentialsProviderTest extends BaseCredentialsProviderTest{
     public static final Region REGION = RegionUtils.getRegion("us-west-2");
     public static final String CERBERUS_TEST_ARN = "arn:aws:iam::123456789012:role/cerberus-test-role";
     public static final String ERROR_RESPONSE = "Error calling cerberus";
+    public static final String SUCCESS_RESPONSE = "{\"auth_data\": \"\"}";
 
     protected static final String MISSING_AUTH_DATA = "{}";
 
@@ -97,6 +105,53 @@ public void getEncryptedAuthData_throws_exception_on_missing_auth_data() throws
         provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);
     }
 
+    @Test
+    public void test_that_getEncryptedAuthData_retries_on_500_errors() {
+        when(urlResolver.resolve()).thenReturn(cerberusUrl);
+
+        for (int i = 0; i < DEFAULT_AUTH_RETRIES - 1; i++) {
+            mockWebServer.enqueue(new MockResponse().setResponseCode(500).setBody(ERROR_RESPONSE));
+        }
+        mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(SUCCESS_RESPONSE));
+
+        provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);
+    }
+
+    @Test
+    public void test_that_getEncryptedAuthData_retries_on_IOException_errors() throws IOException {
+        when(urlResolver.resolve()).thenReturn(cerberusUrl);
+
+        OkHttpClient httpClient = mock(OkHttpClient.class);
+        Call call = mock(Call.class);
+        when(call.execute()).thenThrow(new IOException());
+        when(httpClient.newCall(any(Request.class))).thenReturn(call);
+        TestAwsCredentialsProvider provider = new TestAwsCredentialsProvider(urlResolver, httpClient);
+
+        try {
+            provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);
+
+            // code should not reach this point, throw an error if it does
+            throw new AssertionError("Expected CerberusClientException, but was not thrown");
+        } catch(CerberusClientException cce) {  // catch this error so that the remaining tests will run
+            // ensure that error is thrown because of mocked IOException
+            if ( !(cce.getCause() instanceof IOException) ) {
+                throw new AssertionError("Expected error cause to be IOException, but was " + cce.getCause().getClass());
+            }
+        }
+
+        verify(httpClient, times(DEFAULT_AUTH_RETRIES)).newCall(any(Request.class));
+    }
+
+    @Test
+    public void test_that_getEncryptedAuthData_does_not_retry_on_200() {
+        when(urlResolver.resolve()).thenReturn(cerberusUrl);
+
+        mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(SUCCESS_RESPONSE));
+        mockWebServer.enqueue(new MockResponse().setResponseCode(500).setBody(ERROR_RESPONSE));
+
+        provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);
+    }
+
     class TestAwsCredentialsProvider extends BaseAwsCredentialsProvider {
         /**
          * Constructor to setup credentials provider using the specified
@@ -108,6 +163,10 @@ public TestAwsCredentialsProvider(UrlResolver urlResolver) {
             super(urlResolver);
         }
 
+        public TestAwsCredentialsProvider(UrlResolver urlResolver, OkHttpClient client) {
+            super(urlResolver, client);
+        }
+
         @Override
         protected void authenticate() {
 