Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Make CMS use the region that the request was signed for rather than l…
Browse files Browse the repository at this point in the history
…ocking into a single region (#176)
  • Loading branch information
fieldju authored Sep 4, 2018
1 parent 0835a17 commit 2d70260
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/main/java/com/nike/cerberus/aws/sts/AwsStsClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public AwsStsClient(AwsStsHttpClient httpClient) {
}

public GetCallerIdentityResponse getCallerIdentity(AwsStsHttpHeader header) {
GetCallerIdentityFullResponse response = httpClient.execute(header.generateHeaders(), GetCallerIdentityFullResponse.class);
GetCallerIdentityFullResponse response = httpClient.execute(header.getRegion(), header.generateHeaders(), GetCallerIdentityFullResponse.class);
return response.getGetCallerIdentityResponse();
}
}
16 changes: 9 additions & 7 deletions src/main/java/com/nike/cerberus/aws/sts/AwsStsHttpClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class AwsStsHttpClient {

private static final MediaType DEFAULT_ACCEPTED_MEDIA_TYPE = MediaType.parse("application/json");

private static final String DEFAULT_AWS_STS_ENDPOINT = "https://sts.amazonaws.com";
private static final String AWS_STS_ENDPOINT_TEMPLATE = "https://sts.%s.amazonaws.com";

private static final String DEFAULT_GET_CALLER_IDENTITY_ACTION = "Action=GetCallerIdentity&Version=2011-06-15";

Expand All @@ -72,14 +72,16 @@ public AwsStsHttpClient(@Named(AwsStsGuiceModule.AWS_STS_HTTP_CLIENT_NAME) final
/**
* Executes the HTTP request based on the input parameters.
*
* @param headers HTTP Headers to include in the request
* @param region The region to call sts get caller identity in.
* @param headers HTTP Headers to include in the request
* @param responseClass The class of the response object
* @return Response from the server
*/
public <M> M execute(final Map<String, String> headers,
final Class<M> responseClass) {
public <M> M execute(final String region,
final Map<String, String> headers,
final Class<M> responseClass) {
try {
Request request = buildRequest(headers);
Request request = buildRequest(region, headers);
Response response = executeRequestWithRetry(request, DEFAULT_AUTH_RETRIES, DEFAULT_RETRY_INTERVAL_IN_MILLIS);
if (response.code() >= 400 && response.code() < 500) {
ApiException.Builder builder = ApiException.newBuilder();
Expand Down Expand Up @@ -117,8 +119,8 @@ public <M> M execute(final Map<String, String> headers,
*
* @throws JsonProcessingException
*/
protected Request buildRequest(Map<String, String> headers) {
Request.Builder requestBuilder = new Request.Builder().url(DEFAULT_AWS_STS_ENDPOINT)
protected Request buildRequest(String region, Map<String, String> headers) {
Request.Builder requestBuilder = new Request.Builder().url(String.format(AWS_STS_ENDPOINT_TEMPLATE, region))
.addHeader("Accept", DEFAULT_ACCEPTED_MEDIA_TYPE.toString());

if (headers != null) {
Expand Down
32 changes: 32 additions & 0 deletions src/main/java/com/nike/cerberus/aws/sts/AwsStsHttpHeader.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@

package com.nike.cerberus.aws.sts;

import com.amazonaws.regions.Regions;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import com.nike.backstopper.exception.ApiException;
import com.nike.cerberus.error.DefaultApiError;

import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.nike.cerberus.aws.sts.AwsStsHttpHeaders.HEADER_AUTHORIZATION;
import static com.nike.cerberus.aws.sts.AwsStsHttpHeaders.HEADER_X_AMZ_DATE;
Expand Down Expand Up @@ -51,6 +56,33 @@ public Map<String, String> generateHeaders() {
return headers;
}

public String getRegion() {
Pattern pattern = Pattern.compile(".*Credential=.*?/\\d+/(?<region>.*?)/.*");
Matcher matcher = pattern.matcher(authorization);
boolean didMatch = matcher.matches();

if (!didMatch) {
throw ApiException.newBuilder()
.withApiErrors(DefaultApiError.GENERIC_BAD_REQUEST)
.withExceptionMessage(String.format("Failed to determine region from header %s.", authorization))
.build();
}

String region = matcher.group("region");

try {
//noinspection ResultOfMethodCallIgnored
Regions.fromName(region);
} catch (IllegalArgumentException e) {
throw ApiException.newBuilder()
.withApiErrors(DefaultApiError.GENERIC_BAD_REQUEST)
.withExceptionMessage(String.format("Invalid region supplied %s.", region))
.build();
}

return region;
}

public String getAmzDate() {
return amzDate;
}
Expand Down
13 changes: 8 additions & 5 deletions src/test/java/com/nike/cerberus/aws/sts/AwsStsClientTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ public class AwsStsClientTest {
public void setup() {
httpClient = mock(AwsStsHttpClient.class);
awsStsClient = new AwsStsClient(httpClient);
awsStsHttpHeader = new AwsStsHttpHeader("test amz date",
"test amz security token", "test authorization");
awsStsHttpHeader = new AwsStsHttpHeader(
"test amz date",
"test amz security token",
"AWS4-HMAC-SHA256 Credential=ASIA5S2FQS2GYQLK5FFF/20180904/us-west-2/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ddb9417d2b9bfe6f8b03e31a8f5d8ab98e0f4alkj12312098asdf"
);
}

@Test
Expand All @@ -44,7 +47,7 @@ public void test_getCallerIdentity() {

GetCallerIdentityFullResponse response = mock(GetCallerIdentityFullResponse.class);

when(httpClient.execute(awsStsHttpHeader.generateHeaders(), GetCallerIdentityFullResponse.class))
when(httpClient.execute(awsStsHttpHeader.getRegion(), awsStsHttpHeader.generateHeaders(), GetCallerIdentityFullResponse.class))
.thenReturn(response);

// invoke method under test
Expand All @@ -57,7 +60,7 @@ private void setupMocks() {

GetCallerIdentityFullResponse response = new GetCallerIdentityFullResponse();

when(httpClient.execute(awsStsHttpHeader.generateHeaders(), GetCallerIdentityFullResponse.class))
when(httpClient.execute(awsStsHttpHeader.getRegion(), awsStsHttpHeader.generateHeaders(), GetCallerIdentityFullResponse.class))
.thenReturn(response);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void test_execute() throws Exception {
when(httpClient.newCall(any())).thenReturn(call);

// invoke method under test
awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class);
awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class);
}

@Test(expected = ApiException.class)
Expand All @@ -77,18 +77,18 @@ public void test_execute_handles_error() throws Exception {
when(httpClient.newCall(any())).thenReturn(call);

// invoke method under test
awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class);
awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class);
}

@Test
public void test_buildRequest() throws Exception {
String method = "POST";

// invoke method under test
Request request = awsStsHttpClient.buildRequest(Maps.newHashMap());
Request request = awsStsHttpClient.buildRequest("us-west-2", Maps.newHashMap());

assertEquals(2, request.headers().size());
assertEquals("https://sts.amazonaws.com/", request.url().uri().toString());
assertEquals("https://sts.us-west-2.amazonaws.com/", request.url().uri().toString());
assertEquals(method, request.method());
assertEquals(43, request.body().contentLength());
}
Expand Down Expand Up @@ -125,7 +125,7 @@ public void test_4xx_response() throws Exception {
when(httpClient.newCall(any())).thenReturn(call);

// invoke method under test
awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class);
awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class);
}

@Test
Expand All @@ -137,7 +137,7 @@ public void test_does_not_retry_on_2xx() throws IOException {
when(successCall.execute()).thenReturn(createFakeResponse(200, "test arn"));

when(httpClient.newCall(any())).thenReturn(successCall).thenReturn(failCall);
GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class);
GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class);
assertThat(actualResponse).isNotNull();
assertThat(actualResponse.getGetCallerIdentityResponse().getGetCallerIdentityResult().getArn()).isEqualToIgnoringCase("test arn");
}
Expand All @@ -151,7 +151,7 @@ public void test_does_not_retry_on_4xx() throws IOException {
when(successCall.execute()).thenReturn(createFakeResponse(200, "test arn"));

when(httpClient.newCall(any())).thenReturn(failCall).thenReturn(successCall);
GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class);
GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class);
assertThat(actualResponse).isNotNull();
assertThat(actualResponse.getGetCallerIdentityResponse().getGetCallerIdentityResult().getArn()).isEqualToIgnoringCase("test arn");
}
Expand All @@ -167,7 +167,7 @@ public void test_retries_on_5xx_errors() throws IOException {
when(successCall.execute()).thenReturn(createFakeResponse(200, "test arn"));

when(httpClient.newCall(any())).thenReturn(failCall).thenReturn(failCall).thenReturn(successCall);
GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, GetCallerIdentityFullResponse.class);
GetCallerIdentityFullResponse actualResponse = awsStsHttpClient.execute(null, null, GetCallerIdentityFullResponse.class);
assertThat(actualResponse).isNotNull();
assertThat(actualResponse.getGetCallerIdentityResponse().getGetCallerIdentityResult().getArn()).isEqualToIgnoringCase("test arn");
}
Expand Down
28 changes: 28 additions & 0 deletions src/test/java/com/nike/cerberus/aws/sts/AwsStsHttpHeaderTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.nike.cerberus.aws.sts;

import org.junit.Test;

import static junit.framework.TestCase.assertEquals;

public class AwsStsHttpHeaderTest {

@Test
public void test_getRegion_returns_region_as_expected() {
AwsStsHttpHeader header = new AwsStsHttpHeader(
"20180904T205115Z",
"FQoGZXIvYXdzEFYaDEYceadsfLKJLKlkj908098oB/rJIdxdo57fx3Ef2wW8WhFbSpLGg3hwNqhuepdkf/c0F7OXJutqM2yjgnZCiO7SPAdnMSJhoEgH7SJlkPaPfiRzZAf0yxxD6e4z0VJU74uQfbgfZpn5RL+JyDpgoYkUrjuyL8zRB1knGSOCi32Q75+asdfasd+7bWxMyJIKEb/HF2Le8xM/9F4WRqa5P0+asdfasdfasdf+MGlDlNG0KTzg1JT6QXf95ozWR5bBFSz5DbrFhXhMegMQ7+7Kvx+asdfasdl.jlkj++5NpRRlE54cct7+aG3HQskow9y73AU=",
"AWS4-HMAC-SHA256 Credential=ASIA5S2FQS2GYQLK5FFF/20180904/us-west-2/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ddb9417d2b9bfe6f8b03e31a8f5d8ab98e0f4alkj12312098asdf"
);

assertEquals("us-west-2", header.getRegion());

header = new AwsStsHttpHeader(
"20180904T205115Z",
"FQoGZXIvYXdzEFYaDEYceadsfLKJLKlkj908098oB/rJIdxdo57fx3Ef2wW8WhFbSpLGg3hwNqhuepdkf/c0F7OXJutqM2yjgnZCiO7SPAdnMSJhoEgH7SJlkPaPfiRzZAf0yxxD6e4z0VJU74uQfbgfZpn5RL+JyDpgoYkUrjuyL8zRB1knGSOCi32Q75+asdfasd+7bWxMyJIKEb/HF2Le8xM/9F4WRqa5P0+asdfasdfasdf+MGlDlNG0KTzg1JT6QXf95ozWR5bBFSz5DbrFhXhMegMQ7+7Kvx+asdfasdl.jlkj++5NpRRlE54cct7+aG3HQskow9y73AU=",
"AWS4-HMAC-SHA256 Credential=ASIA5S2FQS2GYQLK5FFF/20180904/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ddb9417d2b9bfe6f8b03e31a8f5d8ab98e0f4alkj12312098asdf"
);

assertEquals("us-east-1", header.getRegion());
}

}

0 comments on commit 2d70260

Please sign in to comment.