diff --git a/src/main/java/com/nike/cerberus/aws/sts/AwsStsClient.java b/src/main/java/com/nike/cerberus/aws/sts/AwsStsClient.java index f1a2dfe92..76c471b23 100644 --- a/src/main/java/com/nike/cerberus/aws/sts/AwsStsClient.java +++ b/src/main/java/com/nike/cerberus/aws/sts/AwsStsClient.java @@ -33,7 +33,7 @@ public AwsStsClient(AwsStsHttpClient httpClient) { } public GetCallerIdentityResponse getCallerIdentity(AwsStsHttpHeader header) { - GetCallerIdentityFullResponse response = httpClient.execute(header.generateHeaders(), GetCallerIdentityFullResponse.class); + GetCallerIdentityFullResponse response = httpClient.execute(header.getRegion(), header.generateHeaders(), GetCallerIdentityFullResponse.class); return response.getGetCallerIdentityResponse(); } } diff --git a/src/main/java/com/nike/cerberus/aws/sts/AwsStsHttpClient.java b/src/main/java/com/nike/cerberus/aws/sts/AwsStsHttpClient.java index b55981b79..cb4cbc74a 100644 --- a/src/main/java/com/nike/cerberus/aws/sts/AwsStsHttpClient.java +++ b/src/main/java/com/nike/cerberus/aws/sts/AwsStsHttpClient.java @@ -49,7 +49,7 @@ public class AwsStsHttpClient { private static final MediaType DEFAULT_ACCEPTED_MEDIA_TYPE = MediaType.parse("application/json"); - private static final String DEFAULT_AWS_STS_ENDPOINT = "https://sts.amazonaws.com"; + private static final String AWS_STS_ENDPOINT_TEMPLATE = "https://sts.%s.amazonaws.com"; private static final String DEFAULT_GET_CALLER_IDENTITY_ACTION = "Action=GetCallerIdentity&Version=2011-06-15"; @@ -72,14 +72,16 @@ public AwsStsHttpClient(@Named(AwsStsGuiceModule.AWS_STS_HTTP_CLIENT_NAME) final /** * Executes the HTTP request based on the input parameters. * - * @param headers HTTP Headers to include in the request + * @param region The region to call sts get caller identity in. + * @param headers HTTP Headers to include in the request * @param responseClass The class of the response object * @return Response from the server */ - public M execute(final Map headers, - final Class responseClass) { + public M execute(final String region, + final Map headers, + final Class responseClass) { try { - Request request = buildRequest(headers); + Request request = buildRequest(region, headers); Response response = executeRequestWithRetry(request, DEFAULT_AUTH_RETRIES, DEFAULT_RETRY_INTERVAL_IN_MILLIS); if (response.code() >= 400 && response.code() < 500) { ApiException.Builder builder = ApiException.newBuilder(); @@ -117,8 +119,8 @@ public M execute(final Map headers, * * @throws JsonProcessingException */ - protected Request buildRequest(Map headers) { - Request.Builder requestBuilder = new Request.Builder().url(DEFAULT_AWS_STS_ENDPOINT) + protected Request buildRequest(String region, Map headers) { + Request.Builder requestBuilder = new Request.Builder().url(String.format(AWS_STS_ENDPOINT_TEMPLATE, region)) .addHeader("Accept", DEFAULT_ACCEPTED_MEDIA_TYPE.toString()); if (headers != null) { diff --git a/src/main/java/com/nike/cerberus/aws/sts/AwsStsHttpHeader.java b/src/main/java/com/nike/cerberus/aws/sts/AwsStsHttpHeader.java index d708b1522..213a6e8ad 100644 --- a/src/main/java/com/nike/cerberus/aws/sts/AwsStsHttpHeader.java +++ b/src/main/java/com/nike/cerberus/aws/sts/AwsStsHttpHeader.java @@ -16,10 +16,15 @@ package com.nike.cerberus.aws.sts; +import com.amazonaws.regions.Regions; import com.google.common.base.Preconditions; import com.google.common.collect.Maps; +import com.nike.backstopper.exception.ApiException; +import com.nike.cerberus.error.DefaultApiError; import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static com.nike.cerberus.aws.sts.AwsStsHttpHeaders.HEADER_AUTHORIZATION; import static com.nike.cerberus.aws.sts.AwsStsHttpHeaders.HEADER_X_AMZ_DATE; @@ -51,6 +56,33 @@ public Map generateHeaders() { return headers; } + public String getRegion() { + Pattern pattern = Pattern.compile(".*Credential=.*?/\\d+/(?.*?)/.*"); + Matcher matcher = pattern.matcher(authorization); + boolean didMatch = matcher.matches(); + + if (!didMatch) { + throw ApiException.newBuilder() + .withApiErrors(DefaultApiError.GENERIC_BAD_REQUEST) + .withExceptionMessage(String.format("Failed to determine region from header %s.", authorization)) + .build(); + } + + String region = matcher.group("region"); + + try { + //noinspection ResultOfMethodCallIgnored + Regions.fromName(region); + } catch (IllegalArgumentException e) { + throw ApiException.newBuilder() + .withApiErrors(DefaultApiError.GENERIC_BAD_REQUEST) + .withExceptionMessage(String.format("Invalid region supplied %s.", region)) + .build(); + } + + return region; + } + public String getAmzDate() { return amzDate; } diff --git a/src/test/java/com/nike/cerberus/aws/sts/AwsStsClientTest.java b/src/test/java/com/nike/cerberus/aws/sts/AwsStsClientTest.java index a6ff44b51..7c2b28c64 100644 --- a/src/test/java/com/nike/cerberus/aws/sts/AwsStsClientTest.java +++ b/src/test/java/com/nike/cerberus/aws/sts/AwsStsClientTest.java @@ -33,8 +33,11 @@ public class AwsStsClientTest { public void setup() { httpClient = mock(AwsStsHttpClient.class); awsStsClient = new AwsStsClient(httpClient); - awsStsHttpHeader = new AwsStsHttpHeader("test amz date", - "test amz security token", "test authorization"); + awsStsHttpHeader = new AwsStsHttpHeader( + "test amz date", + "test amz security token", + "AWS4-HMAC-SHA256 Credential=ASIA5S2FQS2GYQLK5FFF/20180904/us-west-2/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ddb9417d2b9bfe6f8b03e31a8f5d8ab98e0f4alkj12312098asdf" + ); } @Test @@ -44,7 +47,7 @@ public void test_getCallerIdentity() { GetCallerIdentityFullResponse response = mock(GetCallerIdentityFullResponse.class); - when(httpClient.execute(awsStsHttpHeader.generateHeaders(), GetCallerIdentityFullResponse.class)) + when(httpClient.execute(awsStsHttpHeader.getRegion(), awsStsHttpHeader.generateHeaders(), GetCallerIdentityFullResponse.class)) .thenReturn(response); // invoke method under test @@ -57,7 +60,7 @@ private void setupMocks() { GetCallerIdentityFullResponse response = new GetCallerIdentityFullResponse(); - when(httpClient.execute(awsStsHttpHeader.generateHeaders(), GetCallerIdentityFullResponse.class)) + when(httpClient.execute(awsStsHttpHeader.getRegion(), awsStsHttpHeader.generateHeaders(), GetCallerIdentityFullResponse.class)) .thenReturn(response); } -} \ No newline at end of file +} diff --git a/src/test/java/com/nike/cerberus/aws/sts/AwsStsHttpClientTest.java b/src/test/java/com/nike/cerberus/aws/sts/AwsStsHttpClientTest.java index d59838336..676522729 100644 --- a/src/test/java/com/nike/cerberus/aws/sts/AwsStsHttpClientTest.java +++ b/src/test/java/com/nike/cerberus/aws/sts/AwsStsHttpClientTest.java @@ -66,7 +66,7 @@ public void test_execute() throws Exception { when(httpClient.newCall(any())).thenReturn(call); // invoke method under test - awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class); + awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class); } @Test(expected = ApiException.class) @@ -77,7 +77,7 @@ public void test_execute_handles_error() throws Exception { when(httpClient.newCall(any())).thenReturn(call); // invoke method under test - awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class); + awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class); } @Test @@ -85,10 +85,10 @@ public void test_buildRequest() throws Exception { String method = "POST"; // invoke method under test - Request request = awsStsHttpClient.buildRequest(Maps.newHashMap()); + Request request = awsStsHttpClient.buildRequest("us-west-2", Maps.newHashMap()); assertEquals(2, request.headers().size()); - assertEquals("https://sts.amazonaws.com/", request.url().uri().toString()); + assertEquals("https://sts.us-west-2.amazonaws.com/", request.url().uri().toString()); assertEquals(method, request.method()); assertEquals(43, request.body().contentLength()); } @@ -125,7 +125,7 @@ public void test_4xx_response() throws Exception { when(httpClient.newCall(any())).thenReturn(call); // invoke method under test - awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class); + awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class); } @Test @@ -137,7 +137,7 @@ public void test_does_not_retry_on_2xx() throws IOException { when(successCall.execute()).thenReturn(createFakeResponse(200, "test arn")); when(httpClient.newCall(any())).thenReturn(successCall).thenReturn(failCall); - GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class); + GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class); assertThat(actualResponse).isNotNull(); assertThat(actualResponse.getGetCallerIdentityResponse().getGetCallerIdentityResult().getArn()).isEqualToIgnoringCase("test arn"); } @@ -151,7 +151,7 @@ public void test_does_not_retry_on_4xx() throws IOException { when(successCall.execute()).thenReturn(createFakeResponse(200, "test arn")); when(httpClient.newCall(any())).thenReturn(failCall).thenReturn(successCall); - GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class); + GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class); assertThat(actualResponse).isNotNull(); assertThat(actualResponse.getGetCallerIdentityResponse().getGetCallerIdentityResult().getArn()).isEqualToIgnoringCase("test arn"); } @@ -167,7 +167,7 @@ public void test_retries_on_5xx_errors() throws IOException { when(successCall.execute()).thenReturn(createFakeResponse(200, "test arn")); when(httpClient.newCall(any())).thenReturn(failCall).thenReturn(failCall).thenReturn(successCall); - GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class); + GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class); assertThat(actualResponse).isNotNull(); assertThat(actualResponse.getGetCallerIdentityResponse().getGetCallerIdentityResult().getArn()).isEqualToIgnoringCase("test arn"); } diff --git a/src/test/java/com/nike/cerberus/aws/sts/AwsStsHttpHeaderTest.java b/src/test/java/com/nike/cerberus/aws/sts/AwsStsHttpHeaderTest.java new file mode 100644 index 000000000..bda3fb7a2 --- /dev/null +++ b/src/test/java/com/nike/cerberus/aws/sts/AwsStsHttpHeaderTest.java @@ -0,0 +1,28 @@ +package com.nike.cerberus.aws.sts; + +import org.junit.Test; + +import static junit.framework.TestCase.assertEquals; + +public class AwsStsHttpHeaderTest { + + @Test + public void test_getRegion_returns_region_as_expected() { + AwsStsHttpHeader header = new AwsStsHttpHeader( + "20180904T205115Z", + "FQoGZXIvYXdzEFYaDEYceadsfLKJLKlkj908098oB/rJIdxdo57fx3Ef2wW8WhFbSpLGg3hwNqhuepdkf/c0F7OXJutqM2yjgnZCiO7SPAdnMSJhoEgH7SJlkPaPfiRzZAf0yxxD6e4z0VJU74uQfbgfZpn5RL+JyDpgoYkUrjuyL8zRB1knGSOCi32Q75+asdfasd+7bWxMyJIKEb/HF2Le8xM/9F4WRqa5P0+asdfasdfasdf+MGlDlNG0KTzg1JT6QXf95ozWR5bBFSz5DbrFhXhMegMQ7+7Kvx+asdfasdl.jlkj++5NpRRlE54cct7+aG3HQskow9y73AU=", + "AWS4-HMAC-SHA256 Credential=ASIA5S2FQS2GYQLK5FFF/20180904/us-west-2/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ddb9417d2b9bfe6f8b03e31a8f5d8ab98e0f4alkj12312098asdf" + ); + + assertEquals("us-west-2", header.getRegion()); + + header = new AwsStsHttpHeader( + "20180904T205115Z", + "FQoGZXIvYXdzEFYaDEYceadsfLKJLKlkj908098oB/rJIdxdo57fx3Ef2wW8WhFbSpLGg3hwNqhuepdkf/c0F7OXJutqM2yjgnZCiO7SPAdnMSJhoEgH7SJlkPaPfiRzZAf0yxxD6e4z0VJU74uQfbgfZpn5RL+JyDpgoYkUrjuyL8zRB1knGSOCi32Q75+asdfasd+7bWxMyJIKEb/HF2Le8xM/9F4WRqa5P0+asdfasdfasdf+MGlDlNG0KTzg1JT6QXf95ozWR5bBFSz5DbrFhXhMegMQ7+7Kvx+asdfasdl.jlkj++5NpRRlE54cct7+aG3HQskow9y73AU=", + "AWS4-HMAC-SHA256 Credential=ASIA5S2FQS2GYQLK5FFF/20180904/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ddb9417d2b9bfe6f8b03e31a8f5d8ab98e0f4alkj12312098asdf" + ); + + assertEquals("us-east-1", header.getRegion()); + } + +}