diff --git a/gradle.properties b/gradle.properties index 21712d5..15d6e28 100644 --- a/gradle.properties +++ b/gradle.properties @@ -14,6 +14,6 @@ # limitations under the License. # -version=1.5.0 +version=2.0.0 groupId=com.nike artifactId=cerberus-client diff --git a/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java b/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java index 7cab28b..3347a86 100644 --- a/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java +++ b/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java @@ -59,6 +59,8 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import static com.nike.cerberus.client.auth.aws.StaticIamRoleVaultCredentialsProvider.IAM_ROLE_ARN_FORMAT; + /** * {@link VaultCredentialsProvider} implementation that uses some AWS * credentials provider to authenticate with Cerberus and decrypt the auth @@ -141,28 +143,28 @@ public VaultCredentials getCredentials() { * * @param accountId * AWS account ID used to auth with cerberus - * @param iamRole + * @param iamRoleName * IAM role name used to auth with cerberus */ - protected void getAndSetToken(final String accountId, final String iamRole) { - getAndSetToken(accountId, iamRole, Regions.getCurrentRegion()); + protected void getAndSetToken(final String accountId, final String iamRoleName) { + final String iamRoleArn = String.format(IAM_ROLE_ARN_FORMAT, accountId, iamRoleName); + + getAndSetToken(iamRoleArn, Regions.getCurrentRegion()); } /** * Authenticates with Cerberus and decrypts and sets the token and expiration details. * - * @param accountId - * AWS account ID used to auth with cerberus - * @param iamRole - * IAM role name used to auth with cerberus + * @param iamPrincipalArn + * AWS IAM principal ARN used to auth with cerberus * @param region * AWS Region used in auth with cerberus */ - protected void getAndSetToken(final String accountId, final String iamRole, final Region region) { + protected void getAndSetToken(final String iamPrincipalArn, final Region region) { final AWSKMSClient kmsClient = new AWSKMSClient(); kmsClient.setRegion(region); - final String encryptedAuthData = getEncryptedAuthData(accountId, iamRole, region); + final String encryptedAuthData = getEncryptedAuthData(iamPrincipalArn, region); final VaultAuthResponse decryptedToken = decryptToken(kmsClient, encryptedAuthData); final DateTime expires = DateTime.now(DateTimeZone.UTC); expires.plusSeconds(decryptedToken.getLeaseDuration() - paddingTimeInSeconds); @@ -173,16 +175,13 @@ protected void getAndSetToken(final String accountId, final String iamRole, fina /** * Retrieves the encrypted auth response from Cerberus. - * - * @param accountId - * AWS account ID used in the row key - * @param roleName - * IAM role name used in the row key + * @param iamPrincipalArn + * IAM principal ARN used in the row key * @param region - * Current region of the running function or instance + * Current region of the running function or instance * @return Base64 and encrypted token */ - protected String getEncryptedAuthData(final String accountId, final String roleName, Region region) { + protected String getEncryptedAuthData(final String iamPrincipalArn, Region region) { final String url = urlResolver.resolve(); if (StringUtils.isBlank(url)) { @@ -191,14 +190,14 @@ protected String getEncryptedAuthData(final String accountId, final String roleN final OkHttpClient httpClient = new OkHttpClient(); - LOGGER.info(String.format("Attempting to authenticate with AWS account id [%s] and role [%s] against [%s]", - accountId, roleName, url)); + LOGGER.info(String.format("Attempting to authenticate with AWS IAM principal ARN [%s] against [%s]", + iamPrincipalArn, url)); try { - Request.Builder requestBuilder = new Request.Builder().url(url + "/v1/auth/iam-role") + Request.Builder requestBuilder = new Request.Builder().url(url + "/v2/auth/iam-principal") .addHeader(HttpHeader.ACCEPT, DEFAULT_MEDIA_TYPE.toString()) .addHeader(HttpHeader.CONTENT_TYPE, DEFAULT_MEDIA_TYPE.toString()) - .method(HttpMethod.POST, buildCredentialsRequestBody(accountId, roleName, region)); + .method(HttpMethod.POST, buildCredentialsRequestBody(iamPrincipalArn, region)); Response response = httpClient.newCall(requestBuilder.build()).execute(); @@ -249,12 +248,11 @@ protected VaultAuthResponse decryptToken(AWSKMS kmsClient, String encryptedToken return gson.fromJson(decryptedAuthData, VaultAuthResponse.class); } - private RequestBody buildCredentialsRequestBody(final String accountId, final String roleName, Region region) { + private RequestBody buildCredentialsRequestBody(final String iamPrincipalArn, Region region) { final String regionName = region == null ? Regions.getCurrentRegion().getName() : region.getName(); final Map credentials = new HashMap<>(); - credentials.put("account_id", accountId); - credentials.put("role_name", roleName); + credentials.put("iam_principal_arn", iamPrincipalArn); credentials.put("region", regionName); return RequestBody.create(DEFAULT_MEDIA_TYPE, gson.toJson(credentials)); diff --git a/src/main/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProvider.java b/src/main/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProvider.java index e11dcb0..ea6aa31 100644 --- a/src/main/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProvider.java +++ b/src/main/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProvider.java @@ -6,9 +6,6 @@ import com.nike.vault.client.UrlResolver; import com.nike.vault.client.VaultClientException; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - /** * This Credentials provider will look up the assigned InstanceProfileArn for this machine and attempt * To automatically retrieve a Vault token from CMS's iam-auth endpoint that takes region, acct id, role name. @@ -28,42 +25,16 @@ public InstanceProfileVaultCredentialsProvider(UrlResolver urlResolver) { @Override protected void authenticate() { EC2MetadataUtils.IAMInfo iamInfo = getIamInfo(); - IamAuthInfo iamAuthInfo = getIamAuthInfo(iamInfo.instanceProfileArn); + String instanceProfileArn = iamInfo.instanceProfileArn; + Region region = Regions.getCurrentRegion(); try { - getAndSetToken(iamAuthInfo.accountId, iamAuthInfo.roleName, iamAuthInfo.region); + getAndSetToken(instanceProfileArn, region); } catch (Exception e) { throw new VaultClientException(String.format("Failed to authenticate with Cerberus's iam auth endpoint " + - "using the following auth info, acct id: %s, roleName: %s, region: %s", - iamAuthInfo.accountId, iamAuthInfo.roleName, iamAuthInfo.region), e); - } - } - - protected IamAuthInfo getIamAuthInfo(String instanceProfileArn) { - if (instanceProfileArn == null) { - throw new VaultClientException("instanceProfileArn provided was null rather than valid arn"); - } - - IamAuthInfo info = new IamAuthInfo(); - String pattern = "arn:aws:iam::(.*?):instance-profile/(.*)"; - Matcher matcher = Pattern.compile(pattern).matcher(instanceProfileArn); - boolean found = matcher.find(); - if (! found) { - throw new VaultClientException(String.format( - "Failed to find account id and role / instance profile name from ARN: %s using pattern %s", - instanceProfileArn, pattern)); + "using the following auth info, iamPrincipalArn: %s, region: %s", + instanceProfileArn, region), e); } - - info.accountId = matcher.group(1); - info.roleName = matcher.group(2); - - return info; - } - - protected static class IamAuthInfo { - String accountId; - String roleName; - Region region = Regions.getCurrentRegion(); } protected EC2MetadataUtils.IAMInfo getIamInfo() { diff --git a/src/main/java/com/nike/cerberus/client/auth/aws/LambdaRoleVaultCredentialsProvider.java b/src/main/java/com/nike/cerberus/client/auth/aws/LambdaRoleVaultCredentialsProvider.java index a949a6c..11eb6b6 100644 --- a/src/main/java/com/nike/cerberus/client/auth/aws/LambdaRoleVaultCredentialsProvider.java +++ b/src/main/java/com/nike/cerberus/client/auth/aws/LambdaRoleVaultCredentialsProvider.java @@ -103,18 +103,15 @@ protected void authenticate() { throw new IllegalStateException("Lambda function assigned role is not a valid IAM role ARN."); } - final String accountId = roleArnMatcher.group("accountId"); - final String iamRoleArn = roleArnMatcher.group("roleName"); - try { - getAndSetToken(accountId, iamRoleArn, currentRegion); + getAndSetToken(roleArn, currentRegion); return; } catch (AmazonClientException ace) { LOGGER.warn("Unexpected error communicating with AWS services.", ace); } catch (JsonSyntaxException jse) { LOGGER.error("The decrypted auth response was not in the expected format!", jse); } catch (VaultClientException sce) { - LOGGER.warn("Unable to acquire Vault token for IAM role: " + iamRoleArn, sce); + LOGGER.warn("Unable to acquire Vault token for IAM role: " + roleArn, sce); } throw new VaultClientException("Unable to acquire token with Lambda instance role."); diff --git a/src/main/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProvider.java b/src/main/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProvider.java index 7b8cc6b..93d6c34 100644 --- a/src/main/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProvider.java +++ b/src/main/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProvider.java @@ -13,65 +13,57 @@ */ public class StaticIamRoleVaultCredentialsProvider extends BaseAwsCredentialsProvider { - protected String accountId; - protected String roleName; + public static final String IAM_ROLE_ARN_FORMAT = "arn:aws:iam::%s:role/%s"; + protected String iamPrincipalArn; protected Region region; public StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver, String accountId, String roleName, String region) { this(urlResolver); - this.accountId = accountId; - this.roleName = roleName; + this.iamPrincipalArn = generateIamRoleArn(accountId, roleName); this.region = Region.getRegion(Regions.fromName(region)); } public StaticIamRoleVaultCredentialsProvider(String vaultUrl, String accountId, String roleName, String region) { this(new StaticVaultUrlResolver(vaultUrl)); - this.accountId = accountId; - this.roleName = roleName; + this.iamPrincipalArn = generateIamRoleArn(accountId, roleName); this.region = Region.getRegion(Regions.fromName(region)); } public StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver, String accountId, String roleName, Region region) { this(urlResolver); - this.accountId = accountId; - this.roleName = roleName; + this.iamPrincipalArn = generateIamRoleArn(accountId, roleName); this.region = region; } public StaticIamRoleVaultCredentialsProvider(String vaultUrl, String accountId, String roleName, Region region) { this(new StaticVaultUrlResolver(vaultUrl)); - this.accountId = accountId; - this.roleName = roleName; + this.iamPrincipalArn = generateIamRoleArn(accountId, roleName); this.region = region; } public StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver, String iamRoleArn, String region) { this(urlResolver); - this.accountId = getAccountIdFromArn(iamRoleArn); - this.roleName = getRoleNameFromArn(iamRoleArn); + this.iamPrincipalArn = iamRoleArn; this.region = Region.getRegion(Regions.fromName(region)); } public StaticIamRoleVaultCredentialsProvider(String vaultUrl, String iamRoleArn, String region) { this(new StaticVaultUrlResolver(vaultUrl)); - this.accountId = getAccountIdFromArn(iamRoleArn); - this.roleName = getRoleNameFromArn(iamRoleArn); + this.iamPrincipalArn = iamRoleArn; this.region = Region.getRegion(Regions.fromName(region)); } public StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver, String iamRoleArn, Region region) { this(urlResolver); - this.accountId = getAccountIdFromArn(iamRoleArn); - this.roleName = getRoleNameFromArn(iamRoleArn); + this.iamPrincipalArn = iamRoleArn; this.region = region; } public StaticIamRoleVaultCredentialsProvider(String vaultUrl, String iamRoleArn, Region region) { this(new StaticVaultUrlResolver(vaultUrl)); - this.accountId = getAccountIdFromArn(iamRoleArn); - this.roleName = getRoleNameFromArn(iamRoleArn); + this.iamPrincipalArn = iamRoleArn; this.region = region; } @@ -79,28 +71,13 @@ private StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver) { super(urlResolver); } - private String getAccountIdFromArn(String arn) { - Matcher matcher = Pattern.compile("arn:aws:iam::(.*?):role.*").matcher(arn); - boolean found = matcher.find(); - if (found) { - return matcher.group(1); - } + private String generateIamRoleArn(String accountId, String roleName) { - throw new IllegalArgumentException("Invalid IAM role ARN supplied, expected arn:aws:iam::%s:role/%s"); - } - - private String getRoleNameFromArn(String arn) { - Matcher matcher = Pattern.compile("arn:aws:iam::.*?:role/(.*)").matcher(arn); - boolean found = matcher.find(); - if (found) { - return matcher.group(1); - } - - throw new IllegalArgumentException("Invalid IAM role ARN supplied, expected arn:aws:iam::%s:role/%s"); + return String.format(IAM_ROLE_ARN_FORMAT, accountId, roleName); } @Override protected void authenticate() { - getAndSetToken(accountId, roleName, region); + getAndSetToken(iamPrincipalArn, region); } } diff --git a/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java b/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java index a2584e3..732cd11 100644 --- a/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java +++ b/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java @@ -23,8 +23,7 @@ public class BaseAwsCredentialsProviderTest extends BaseCredentialsProviderTest{ public static final Region REGION = RegionUtils.getRegion("us-west-2"); - public static final String ACCOUNT_ID = "123456789012"; - public static final String CERBERUS_TEST_ROLE = "cerberus-test-role"; + public static final String CERBERUS_TEST_ARN = "arn:aws:iam::123456789012:role/cerberus-test-role"; public static final String ERROR_RESPONSE = "Error calling vault"; protected static final String MISSING_AUTH_DATA = "{}"; @@ -56,7 +55,7 @@ public void tearDown() throws Exception { public void getEncryptedAuthData_blank_url_throws_exception() throws Exception { when(urlResolver.resolve()).thenReturn(""); - provider.getEncryptedAuthData(ACCOUNT_ID, CERBERUS_TEST_ROLE, REGION); + provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION); } @Test(expected = VaultClientException.class) @@ -71,7 +70,7 @@ public void getEncryptedAuthData_throws_exception_on_bad_response_code() throws System.setProperty(DefaultCerberusUrlResolver.CERBERUS_ADDR_SYS_PROPERTY, vaultUrl); mockWebServer.enqueue(new MockResponse().setResponseCode(400).setBody(ERROR_RESPONSE)); - provider.getEncryptedAuthData(ACCOUNT_ID, CERBERUS_TEST_ROLE, REGION); + provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION); } @Test(expected = VaultClientException.class) @@ -81,7 +80,7 @@ public void getEncryptedAuthData_throws_exception_on_missing_auth_data() throws System.setProperty(DefaultCerberusUrlResolver.CERBERUS_ADDR_SYS_PROPERTY, vaultUrl); mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(MISSING_AUTH_DATA)); - provider.getEncryptedAuthData(ACCOUNT_ID, CERBERUS_TEST_ROLE, REGION); + provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION); } class TestAwsCredentialsProvider extends BaseAwsCredentialsProvider { diff --git a/src/test/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProviderTest.java b/src/test/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProviderTest.java index de7fa28..7719752 100644 --- a/src/test/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProviderTest.java +++ b/src/test/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProviderTest.java @@ -25,32 +25,12 @@ public void before() { provider = new InstanceProfileVaultCredentialsProvider(new StaticVaultUrlResolver("foo")); } - @Test - public void test_that_valid_arn_gets_parsed() { - InstanceProfileVaultCredentialsProvider.IamAuthInfo info = provider.getIamAuthInfo("arn:aws:iam::1234:inst" + - "ance-profile/base/prod-base-sdfgsdfg-be5c-47ff-b82f-sdfgsdfgsfdg-CmsInstanceProfile-sdfgsdfgsdfg"); - - assertEquals("1234", info.accountId); - assertEquals("base/prod-base-sdfgsdfg-be5c-47ff-b82f-sdfgsdfgsfdg-CmsInstanceProfile-sdfgsdfgsdfg", - info.roleName); - } - - @Test(expected = VaultClientException.class) - public void test_that_invalid_arn_fails() { - provider.getIamAuthInfo(""); - } - - @Test(expected = VaultClientException.class) - public void test_that_null_arn_fails() { - provider.getIamAuthInfo(null); - } - @Test(expected = VaultClientException.class) public void test_that_authenticate_catches_exceptions_and_throws_vault_exception() { InstanceProfileVaultCredentialsProvider providerSpy = spy(provider); - doThrow(new RuntimeException("Foo")).when(providerSpy).getAndSetToken(anyString(), anyString(), any(Region.class)); - doReturn(new InstanceProfileVaultCredentialsProvider.IamAuthInfo()).when(providerSpy).getIamAuthInfo(anyString()); + doThrow(new RuntimeException("Foo")).when(providerSpy).getAndSetToken(anyString(), any(Region.class)); + EC2MetadataUtils.IAMInfo iamInfo = new EC2MetadataUtils.IAMInfo(); iamInfo.instanceProfileArn = "foo"; doReturn(iamInfo).when(providerSpy).getIamInfo(); diff --git a/src/test/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProviderTest.java b/src/test/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProviderTest.java index f1b2f71..7b33dbe 100644 --- a/src/test/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProviderTest.java +++ b/src/test/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProviderTest.java @@ -24,8 +24,7 @@ public void test_constructor_1() { REGION_STRING ); - assertEquals(ACCOUNT_ID, provider.accountId); - assertEquals(ROLE_NAME, provider.roleName); + assertEquals(ROLE_ARN, provider.iamPrincipalArn); assertEquals(REGION, provider.region); } @@ -38,8 +37,7 @@ public void test_constructor_2() { REGION_STRING ); - assertEquals(ACCOUNT_ID, provider.accountId); - assertEquals(ROLE_NAME, provider.roleName); + assertEquals(ROLE_ARN, provider.iamPrincipalArn); assertEquals(REGION, provider.region); } @@ -52,8 +50,7 @@ public void test_constructor_3() { REGION ); - assertEquals(ACCOUNT_ID, provider.accountId); - assertEquals(ROLE_NAME, provider.roleName); + assertEquals(ROLE_ARN, provider.iamPrincipalArn); assertEquals(REGION, provider.region); } @@ -66,8 +63,7 @@ public void test_constructor_4() { REGION ); - assertEquals(ACCOUNT_ID, provider.accountId); - assertEquals(ROLE_NAME, provider.roleName); + assertEquals(ROLE_ARN, provider.iamPrincipalArn); assertEquals(REGION, provider.region); } @@ -79,8 +75,7 @@ public void test_constructor_5() { REGION_STRING ); - assertEquals(ACCOUNT_ID, provider.accountId); - assertEquals(ROLE_NAME, provider.roleName); + assertEquals(ROLE_ARN, provider.iamPrincipalArn); assertEquals(REGION, provider.region); } @@ -92,8 +87,7 @@ public void test_constructor_6() { REGION_STRING ); - assertEquals(ACCOUNT_ID, provider.accountId); - assertEquals(ROLE_NAME, provider.roleName); + assertEquals(ROLE_ARN, provider.iamPrincipalArn); assertEquals(REGION, provider.region); } @@ -105,8 +99,7 @@ public void test_constructor_7() { REGION ); - assertEquals(ACCOUNT_ID, provider.accountId); - assertEquals(ROLE_NAME, provider.roleName); + assertEquals(ROLE_ARN, provider.iamPrincipalArn); assertEquals(REGION, provider.region); } @@ -118,27 +111,8 @@ public void test_constructor_8() { REGION ); - assertEquals(ACCOUNT_ID, provider.accountId); - assertEquals(ROLE_NAME, provider.roleName); + assertEquals(ROLE_ARN, provider.iamPrincipalArn); assertEquals(REGION, provider.region); } - @Test(expected = IllegalArgumentException.class) - public void test_constructor_bad_arn1() { - StaticIamRoleVaultCredentialsProvider provider = new StaticIamRoleVaultCredentialsProvider( - "foo", - "foo", - REGION - ); - } - - @Test(expected = IllegalArgumentException.class) - public void test_constructor_bad_arn2() { - StaticIamRoleVaultCredentialsProvider provider = new StaticIamRoleVaultCredentialsProvider( - "foo", - "arn:aws:iam::123:rolefoo", - REGION - ); - } - }