Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Retry on 500 errors when authenticating or reading secrets (#37)
Browse files Browse the repository at this point in the history
* Retry on 500 errors when authenticating or reading secrets

* Update tests and gradle version to 5.3.0

* Update javadoc
  • Loading branch information
sdford authored May 21, 2018
1 parent bb02d72 commit 225e358
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 6 deletions.
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
version=5.2.0
version=5.3.0
groupId=com.nike
artifactId=cerberus-client
51 changes: 50 additions & 1 deletion src/main/java/com/nike/cerberus/client/CerberusClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/**
* Client for interacting with a Cerberus.
Expand All @@ -63,6 +64,10 @@ public class CerberusClient {

public static final MediaType DEFAULT_MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8");

protected static final int DEFAULT_NUM_RETRIES = 3;

protected static final int DEFAULT_RETRY_INTERVAL_IN_MILLIS = 200;

private final CerberusCredentialsProvider credentialsProvider;

private final OkHttpClient httpClient;
Expand Down Expand Up @@ -228,7 +233,7 @@ public CerberusResponse read(final String path) {
final HttpUrl url = buildUrl(SECRET_PATH_PREFIX, path);
logger.debug("read: requestUrl={}", url);

final Response response = execute(url, HttpMethod.GET, null);
final Response response = executeWithRetry(url, HttpMethod.GET, null, DEFAULT_NUM_RETRIES, DEFAULT_RETRY_INTERVAL_IN_MILLIS);

if (response.code() != HttpStatus.OK) {
parseAndThrowErrorResponse(response);
Expand Down Expand Up @@ -433,6 +438,42 @@ protected HttpUrl buildUrl(final String prefix, final String path) {
return HttpUrl.parse(baseUrl + prefix + path);
}

/**
* Executes an HTTP request and retries if a 500 level error is returned
* @param url Full URL to which to make the HTTP request
* @param method HTTP Method (e.g. GET, PUT, POST)
* @param requestBody Body to add to the request. Nullable
* @param numRetries Maximum number of times to retry on 500 failures
* @param sleepIntervalInMillis Time in milliseconds to sleep between retries. Zero for no sleep
* @return Any HTTP response with status code below 500, or the last error response if only 500's are returned
*/
protected Response executeWithRetry(final HttpUrl url,
final String method,
final Object requestBody,
final int numRetries,
final int sleepIntervalInMillis) {
CerberusClientException exception = null;
Response response = null;
for(int retryNumber = 0; retryNumber < numRetries; retryNumber++) {
try {
response = execute(url, method, requestBody);
if (response.code() < 500) {
return response;
}
} catch (CerberusClientException cce) {
logger.debug(String.format("Failed to call %s %s. Retrying...", method, url), cce);
exception = cce;
}
sleep(sleepIntervalInMillis * (long) Math.pow(2, retryNumber));
}

if (exception != null) {
throw exception;
} else {
return response;
}
}

/**
* Executes the HTTP request based on the input parameters.
*
Expand Down Expand Up @@ -632,4 +673,12 @@ protected byte[] responseBodyAsBytes(Response response) {
throw new CerberusClientException("ERROR failed to print ");
}
}

private void sleep(long milliseconds) {
try {
TimeUnit.MILLISECONDS.sleep(milliseconds);
} catch (InterruptedException ie) {
logger.warn("Sleep interval interrupted.", ie);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

Expand All @@ -80,6 +81,10 @@ public abstract class BaseAwsCredentialsProvider implements CerberusCredentialsP

private static final Logger LOGGER = LoggerFactory.getLogger(BaseAwsCredentialsProvider.class);

protected static final int DEFAULT_AUTH_RETRIES = 3;

protected static final int DEFAULT_RETRY_INTERVAL_IN_MILLIS = 200;

private final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock();

private final Lock readLock = readWriteLock.readLock();
Expand Down Expand Up @@ -250,7 +255,7 @@ protected String getEncryptedAuthData(final String iamPrincipalArn, Region regio
.addHeader(ClientVersion.CERBERUS_CLIENT_HEADER, cerberusJavaClientHeaderValue)
.method(HttpMethod.POST, buildCredentialsRequestBody(iamPrincipalArn, region));

Response response = httpClient.newCall(requestBuilder.build()).execute();
Response response = executeRequestWithRetry(requestBuilder.build(), DEFAULT_AUTH_RETRIES, DEFAULT_RETRY_INTERVAL_IN_MILLIS);

if (response.code() != HttpStatus.OK) {
parseAndThrowErrorResponse(response.code(), response.body().string());
Expand Down Expand Up @@ -298,6 +303,45 @@ protected CerberusAuthResponse decryptToken(AWSKMS kmsClient, String encryptedTo
return gson.fromJson(decryptedAuthData, CerberusAuthResponse.class);
}

/**
* Executes an HTTP request and retries if a 500 level error is returned
* @param request The request to execute
* @param numRetries The maximum number of times to retry
* @param sleepIntervalInMillis Time in milliseconds to sleep between retries. Zero for no sleep.
* @return Any HTTP response with status code below 500, or the last error response if only 500's are returned
* @throws IOException If an IOException occurs during the last retry, then rethrow the error
*/
protected Response executeRequestWithRetry(Request request, int numRetries, int sleepIntervalInMillis) throws IOException {
IOException exception = null;
Response response = null;
for(int retryNumber = 0; retryNumber < numRetries; retryNumber++) {
try {
response = httpClient.newCall(request).execute();
if (response.code() < 500) {
return response;
}
} catch (IOException ioe) {
LOGGER.debug(String.format("Failed to call %s %s. Retrying...", request.method(), request.url()), ioe);
exception = ioe;
}
sleep(sleepIntervalInMillis * (long) Math.pow(2, retryNumber));
}

if (exception != null) {
throw exception;
} else {
return response;
}
}

private void sleep(long milliseconds) {
try {
TimeUnit.MILLISECONDS.sleep(milliseconds);
} catch (InterruptedException ie) {
LOGGER.warn("Sleep interval interrupted.", ie);
}
}

private RequestBody buildCredentialsRequestBody(final String iamPrincipalArn, Region region) {
final String regionName = region == null ? Regions.getCurrentRegion().getName() : region.getName();

Expand Down
63 changes: 60 additions & 3 deletions src/test/java/com/nike/cerberus/client/CerberusClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
import com.nike.cerberus.client.auth.DefaultCerberusCredentialsProviderChain;
import com.nike.cerberus.client.auth.CerberusCredentials;
import com.nike.cerberus.client.auth.CerberusCredentialsProvider;
import com.nike.cerberus.client.http.HttpStatus;
import com.nike.cerberus.client.model.CerberusClientTokenResponse;
import com.nike.cerberus.client.model.CerberusListResponse;
import com.nike.cerberus.client.model.CerberusResponse;
import okhttp3.Call;
import okhttp3.Headers;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
Expand All @@ -42,9 +41,13 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static com.nike.cerberus.client.CerberusClient.DEFAULT_NUM_RETRIES;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
Expand Down Expand Up @@ -135,6 +138,60 @@ public void read_returns_map_of_data_for_specified_path_if_exists() throws IOExc
assertThat(cerberusResponse.getData().get("value")).isEqualToIgnoringCase("world");
}

@Test
public void read_does_not_retry_on_200() throws IOException {
mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson("secret")));
mockWebServer.enqueue(new MockResponse().setResponseCode(500).setBody(getResponseJson("error")));

CerberusResponse cerberusResponse = cerberusClient.read("app/api-key");

assertThat(cerberusResponse).isNotNull();
assertThat(cerberusResponse.getData().containsKey("value")).isTrue();
assertThat(cerberusResponse.getData().get("value")).isEqualToIgnoringCase("world");
}

@Test
public void read_retries_on_500_errors() throws IOException {
for (int i = 0; i < DEFAULT_NUM_RETRIES - 1; i++) {
mockWebServer.enqueue(new MockResponse().setResponseCode(500).setBody(getResponseJson("error")));
}
mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson("secret")));

CerberusResponse cerberusResponse = cerberusClient.read("app/api-key");

assertThat(cerberusResponse).isNotNull();
assertThat(cerberusResponse.getData().containsKey("value")).isTrue();
assertThat(cerberusResponse.getData().get("value")).isEqualToIgnoringCase("world");
}

@Test
public void read_retries_on_IOException() throws IOException {
UrlResolver urlResolver = mock(UrlResolver.class);
when(urlResolver.resolve()).thenReturn("http://localhost:" + mockWebServer.getPort());

OkHttpClient httpClient = mock(OkHttpClient.class);
Call call = mock(Call.class);
when(call.execute()).thenThrow(new IOException());
when(httpClient.newCall(any(Request.class))).thenReturn(call);
final CerberusCredentialsProvider cerberusCredentialsProvider = mock(CerberusCredentialsProvider.class);
when(cerberusCredentialsProvider.getCredentials()).thenReturn(new TestCerberusCredentials());

CerberusClient cerberusClient = new CerberusClient(urlResolver, cerberusCredentialsProvider, httpClient);
try {
cerberusClient.read("app/api-key");

// code should not reach this point, throw an error if it does
throw new AssertionError("Expected CerberusClientException, but was not thrown");
} catch(CerberusClientException cce) { // catch this error so that the remaining tests will run
// ensure that error is thrown because of mocked IOException
if ( !(cce.getCause() instanceof IOException) ) {
throw new AssertionError("Expected error cause to be IOException, but was " + cce.getCause().getClass());
}
}

verify(httpClient, times(DEFAULT_NUM_RETRIES)).newCall(any(Request.class));
}

@Test
public void read_throws_cerberus_server_exception_if_response_is_not_ok() {
final MockResponse response = new MockResponse();
Expand All @@ -156,7 +213,7 @@ public void read_throws_runtime_exception_if_unexpected_error_encountered() thro
final String cerberusUrl = "http://localhost:" + serverSocket.getLocalPort();
final CerberusCredentialsProvider cerberusCredentialsProvider = mock(CerberusCredentialsProvider.class);
final OkHttpClient httpClient = buildHttpClient(1, TimeUnit.SECONDS);
cerberusClient = new CerberusClient(new StaticCerberusUrlResolver(cerberusUrl), cerberusCredentialsProvider, httpClient);
CerberusClient cerberusClient = new CerberusClient(new StaticCerberusUrlResolver(cerberusUrl), cerberusCredentialsProvider, httpClient);

when(cerberusCredentialsProvider.getCredentials()).thenReturn(new TestCerberusCredentials());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import com.nike.cerberus.client.CerberusServerException;
import com.nike.cerberus.client.DefaultCerberusUrlResolver;
import com.nike.cerberus.client.UrlResolver;
import okhttp3.Call;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.After;
Expand All @@ -31,14 +34,19 @@

import java.io.IOException;

import static com.nike.cerberus.client.auth.aws.BaseAwsCredentialsProvider.DEFAULT_AUTH_RETRIES;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.powermock.api.mockito.PowerMockito.mock;
import static org.powermock.api.mockito.PowerMockito.when;

public class BaseAwsCredentialsProviderTest extends BaseCredentialsProviderTest{
public static final Region REGION = RegionUtils.getRegion("us-west-2");
public static final String CERBERUS_TEST_ARN = "arn:aws:iam::123456789012:role/cerberus-test-role";
public static final String ERROR_RESPONSE = "Error calling cerberus";
public static final String SUCCESS_RESPONSE = "{\"auth_data\": \"\"}";

protected static final String MISSING_AUTH_DATA = "{}";

Expand Down Expand Up @@ -97,6 +105,53 @@ public void getEncryptedAuthData_throws_exception_on_missing_auth_data() throws
provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);
}

@Test
public void test_that_getEncryptedAuthData_retries_on_500_errors() {
when(urlResolver.resolve()).thenReturn(cerberusUrl);

for (int i = 0; i < DEFAULT_AUTH_RETRIES - 1; i++) {
mockWebServer.enqueue(new MockResponse().setResponseCode(500).setBody(ERROR_RESPONSE));
}
mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(SUCCESS_RESPONSE));

provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);
}

@Test
public void test_that_getEncryptedAuthData_retries_on_IOException_errors() throws IOException {
when(urlResolver.resolve()).thenReturn(cerberusUrl);

OkHttpClient httpClient = mock(OkHttpClient.class);
Call call = mock(Call.class);
when(call.execute()).thenThrow(new IOException());
when(httpClient.newCall(any(Request.class))).thenReturn(call);
TestAwsCredentialsProvider provider = new TestAwsCredentialsProvider(urlResolver, httpClient);

try {
provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);

// code should not reach this point, throw an error if it does
throw new AssertionError("Expected CerberusClientException, but was not thrown");
} catch(CerberusClientException cce) { // catch this error so that the remaining tests will run
// ensure that error is thrown because of mocked IOException
if ( !(cce.getCause() instanceof IOException) ) {
throw new AssertionError("Expected error cause to be IOException, but was " + cce.getCause().getClass());
}
}

verify(httpClient, times(DEFAULT_AUTH_RETRIES)).newCall(any(Request.class));
}

@Test
public void test_that_getEncryptedAuthData_does_not_retry_on_200() {
when(urlResolver.resolve()).thenReturn(cerberusUrl);

mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(SUCCESS_RESPONSE));
mockWebServer.enqueue(new MockResponse().setResponseCode(500).setBody(ERROR_RESPONSE));

provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);
}

class TestAwsCredentialsProvider extends BaseAwsCredentialsProvider {
/**
* Constructor to setup credentials provider using the specified
Expand All @@ -108,6 +163,10 @@ public TestAwsCredentialsProvider(UrlResolver urlResolver) {
super(urlResolver);
}

public TestAwsCredentialsProvider(UrlResolver urlResolver, OkHttpClient client) {
super(urlResolver, client);
}

@Override
protected void authenticate() {

Expand Down

0 comments on commit 225e358

Please sign in to comment.