Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
mocked credentials for tests, used data provider for test parameters,…
Browse files Browse the repository at this point in the history
… and incorporated provider chain debugger
  • Loading branch information
melanahammel committed Dec 17, 2018
1 parent 408fb75 commit a1044fa
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 27 deletions.
1 change: 1 addition & 0 deletions gradle/dependencies.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ dependencies {
testCompile "org.assertj:assertj-core:2.3.0"
testCompile "com.squareup.okhttp3:mockwebserver:3.7.0"
testCompile "commons-io:commons-io:2.4"
testCompile group: 'com.tngtech.java', name: 'junit-dataprovider', version: '1.10.0'
}

shadowJar {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
public class DefaultAWSCredentialsProviderChainDebugger {

private static final Logger LOGGER = LoggerFactory.getLogger(DefaultAWSCredentialsProviderChainDebugger.class);
private static final String TOKEN_IS_EXPIRED = "The security token included in the request is expired.";
private static final String TOKEN_IS_INVALID = "Invalid credentials";

/**
* This chain should match that found in DefaultAWSCredentialsProviderChain
Expand All @@ -31,10 +33,10 @@ public class DefaultAWSCredentialsProviderChainDebugger {

/**
* Log extra debugging information if appropriate
* @param serviceException exception from Amazon
* @param cerberusErrorMessage error message from Cerberus
*/
public void logExtraDebuggingIfAppropriate(AmazonServiceException serviceException) {
if (StringUtils.contains(serviceException.getMessage(), "The security token included in the request is invalid.")) {
public void logExtraDebuggingIfAppropriate(String cerberusErrorMessage) {
if (StringUtils.contains(cerberusErrorMessage, TOKEN_IS_EXPIRED) || StringUtils.contains(cerberusErrorMessage, TOKEN_IS_INVALID)) {
LOGGER.warn("Bad credentials may have been picked up from the DefaultAWSCredentialsProviderChain");
boolean firstCredentialsFound = false;
for (AWSCredentialsProvider provider : credentialProviderChain) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.amazonaws.DefaultRequest;
import com.amazonaws.auth.AWS4Signer;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProviderChain;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.http.HttpMethodName;
import com.amazonaws.regions.Regions;
Expand Down Expand Up @@ -52,6 +53,8 @@ public class StsCerberusCredentialsProvider extends BaseAwsCredentialsProvider {

protected String regionName;

protected AWSCredentialsProviderChain providerChain;

private static final Logger LOGGER = LoggerFactory.getLogger(BaseAwsCredentialsProvider.class);

private final Gson gson = new GsonBuilder().setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
Expand Down Expand Up @@ -106,11 +109,36 @@ public StsCerberusCredentialsProvider(String cerberusUrl, String region, OkHttpC
}
}

/**
* Constructor to setup credentials provider with specified AWS credentials to sign request
*
* @param cerberusUrl Cerberus URL
* @param region AWS Region used in auth with Cerberus
* @param providerChain AWS Credentials Provider Chain
*/
public StsCerberusCredentialsProvider(String cerberusUrl, String region, AWSCredentialsProviderChain providerChain) {
super(cerberusUrl);

if (region != null ) {
regionName = Regions.fromName(region).getName();
} else {
throw new CerberusClientException("Region is null. Please provide valid AWS region.");
}

this.providerChain = providerChain;
}

/**
* Obtains AWS Credentials.
*/
private AWSCredentials getAWSCredentials(){
return DefaultAWSCredentialsProviderChain.getInstance().getCredentials();

if (providerChain == null) {
return DefaultAWSCredentialsProviderChain.getInstance().getCredentials();
}
else {
return providerChain.getCredentials();
}
}

/**
Expand Down Expand Up @@ -181,12 +209,14 @@ protected CerberusAuthResponse getToken(){
.build();

Response response = executeRequestWithRetry(request, DEFAULT_AUTH_RETRIES, DEFAULT_RETRY_INTERVAL_IN_MILLIS);
String responseBody = response.body().string();

if (response.code() != HttpStatus.OK) {
parseAndThrowErrorResponse(response.code(), response.body().string());
new DefaultAWSCredentialsProviderChainDebugger().logExtraDebuggingIfAppropriate(responseBody);
parseAndThrowErrorResponse(response.code(), responseBody);
}

return gson.fromJson(response.body().string(), CerberusAuthResponse.class);
return gson.fromJson(responseBody, CerberusAuthResponse.class);

} catch (IOException e) {
throw new CerberusClientException("I/O error while communicating with Cerberus", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,19 @@

package com.nike.cerberus.client.auth.aws;

import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProviderChain;
import com.amazonaws.auth.BasicSessionCredentials;
import com.nike.cerberus.client.CerberusClientException;
import com.nike.cerberus.client.model.CerberusAuthResponse;
import com.tngtech.java.junit.dataprovider.DataProviderRunner;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.apache.commons.lang3.StringUtils;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import com.tngtech.java.junit.dataprovider.DataProvider;

import java.io.IOException;
import java.util.Map;
Expand All @@ -34,49 +40,49 @@
/**
* Tests the StsCerberusCredentialsProvider class
*/
@RunWith(DataProviderRunner.class)
public class StsCerberusCredentialsProviderTest {

private static final String REGION_STRING_EAST = "us-east-1";
private static final String REGION_STRING_WEST = "us-west-2";

protected static final String AUTH_RESPONSE = "{\"auth_data\":\"eyJjbGllbnRfdG9rZW4iOiI2NjMyY2I1Zi1mMTBjLTQ1NzItOTU0NS1lNTJmNDdmNmEzZmQiLCAibGVhc2VfZHVyYXRpb24iOiIzNjAwIn0=\"}";
protected static final String DECODED_AUTH_DATA = "{\"client_token\":\"6632cb5f-f10c-4572-9545-e52f47f6a3fd\", \"lease_duration\":\"3600\"}";
protected static final String ERROR_RESPONSE = "Invalid credentials";

private String cerberusUrl;
private AWSCredentialsProviderChain chain;
private AWSCredentials credentials;

@Before
public void setUp() {

cerberusUrl = mock(String.class);

chain = mock(AWSCredentialsProviderChain.class);

credentials = new BasicSessionCredentials("foo", "bar", "cat");
}

@Test
public void test_sts_creds_provider_constructor() {

StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST);
assertThat(credentialsProvider.getCerberusUrl()).isEqualTo(cerberusUrl);
assertThat(credentialsProvider.regionName).isEqualTo(REGION_STRING_EAST);
}

@Test
public void test_get_signed_headers_for_east_region() throws IOException {
MockWebServer mockWebServer = new MockWebServer();
mockWebServer.start();
final String cerberusUrl = "http://localhost:" + mockWebServer.getPort();
StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST);
@DataProvider(value = {
REGION_STRING_EAST,
REGION_STRING_WEST})
public void test_get_signed_headers(String testRegion) throws IOException {

Map<String, String> headers = credentialsProvider.getSignedHeaders();
assertThat(headers).isNotNull();
assertThat(headers.get("Authorization")).isNotEmpty();
assertThat(headers.get("X-Amz-Date")).isNotEmpty();
assertThat(headers.get("X-Amz-Security-Token")).isNotEmpty();
assertThat(headers.get("Host")).isNotEmpty();
}
when(chain.getCredentials()).thenReturn(credentials);

@Test
public void test_get_signed_headers_for_west_region() throws IOException {
MockWebServer mockWebServer = new MockWebServer();
mockWebServer.start();
final String cerberusUrl = "http://localhost:" + mockWebServer.getPort();
StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_WEST);
StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, testRegion, chain);

Map<String, String> headers = credentialsProvider.getSignedHeaders();
assertThat(headers).isNotNull();
Expand All @@ -86,12 +92,16 @@ public void test_get_signed_headers_for_west_region() throws IOException {
assertThat(headers.get("Host")).isNotEmpty();
}


@Test
public void get_token_returns_token() throws IOException {

when(chain.getCredentials()).thenReturn(credentials);

MockWebServer mockWebServer = new MockWebServer();
mockWebServer.start();
final String cerberusUrl = "http://localhost:" + mockWebServer.getPort();
StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST);
StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST, chain);

mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(DECODED_AUTH_DATA));
CerberusAuthResponse token = credentialsProvider.getToken();
Expand All @@ -101,10 +111,13 @@ public void get_token_returns_token() throws IOException {

@Test(expected = CerberusClientException.class)
public void get_token_throws_exception_timeout() throws IOException {

when(chain.getCredentials()).thenReturn(credentials);

MockWebServer mockWebServer = new MockWebServer();
mockWebServer.start();
final String cerberusUrl = "http://localhost:" + mockWebServer.getPort();
StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST);
StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST, chain);

CerberusAuthResponse token = credentialsProvider.getToken();
assertThat(token).isNotNull();
Expand All @@ -113,19 +126,23 @@ public void get_token_throws_exception_timeout() throws IOException {

@Test(expected = CerberusClientException.class)
public void get_token_throws_exception_when_url_is_blank(){
StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST);

StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST, chain);
CerberusAuthResponse token = credentialsProvider.getToken();
assertThat(token).isNotNull();
assertThat(StringUtils.isNotEmpty(token.getClientToken()));
}

@Test(expected = CerberusClientException.class)
public void get_token_throws_exception_when_response_is_bad() throws IOException {

when(chain.getCredentials()).thenReturn(credentials);

MockWebServer mockWebServer = new MockWebServer();
mockWebServer.start();
final String cerberusUrl = "http://localhost:" + mockWebServer.getPort();
StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST);
mockWebServer.enqueue(new MockResponse().setResponseCode(400).setBody(AUTH_RESPONSE));
StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST, chain);
mockWebServer.enqueue(new MockResponse().setResponseCode(400).setBody(ERROR_RESPONSE));

CerberusAuthResponse token = credentialsProvider.getToken();
assertThat(token).isNotNull();
Expand All @@ -134,6 +151,7 @@ public void get_token_throws_exception_when_response_is_bad() throws IOException

@Test(expected = CerberusClientException.class)
public void authenticate_throws_exception_when_token_is_null() {

StsCerberusCredentialsProvider credentialsProvider = new StsCerberusCredentialsProvider(cerberusUrl, REGION_STRING_EAST);
CerberusAuthResponse token = mock(CerberusAuthResponse.class);
when(token.getClientToken()).thenReturn(null);
Expand Down

0 comments on commit a1044fa

Please sign in to comment.