diff --git a/gradle/dependencies.gradle b/gradle/dependencies.gradle index 31b86f0..b476b15 100644 --- a/gradle/dependencies.gradle +++ b/gradle/dependencies.gradle @@ -31,7 +31,7 @@ dependencies { * We do this because AWS constantly breaks backwards compatibility of their SDK with minor version releases. * * We do not want to dictate what SDK version users of Cerberus should use. * ***********************************************************************************************************************/ - compile "com.nike:vault-client:1.0.0" + compile "com.nike:vault-client:1.1.0" compile "joda-time:joda-time:2.8.1" compile "org.apache.commons:commons-lang3:3.4" compile "org.slf4j:slf4j-api:1.7.14" diff --git a/src/main/java/com/nike/cerberus/client/auth/DefaultCerberusCredentialsProviderChain.java b/src/main/java/com/nike/cerberus/client/auth/DefaultCerberusCredentialsProviderChain.java index cefb9e4..e7049ef 100644 --- a/src/main/java/com/nike/cerberus/client/auth/DefaultCerberusCredentialsProviderChain.java +++ b/src/main/java/com/nike/cerberus/client/auth/DefaultCerberusCredentialsProviderChain.java @@ -17,6 +17,7 @@ package com.nike.cerberus.client.auth; import com.nike.cerberus.client.DefaultCerberusUrlResolver; +import com.nike.cerberus.client.auth.aws.InstanceProfileVaultCredentialsProvider; import com.nike.cerberus.client.auth.aws.InstanceRoleVaultCredentialsProvider; import com.nike.vault.client.UrlResolver; import com.nike.vault.client.auth.VaultCredentialsProviderChain; @@ -51,6 +52,7 @@ public DefaultCerberusCredentialsProviderChain() { public DefaultCerberusCredentialsProviderChain(UrlResolver urlResolver) { super(new EnvironmentCerberusCredentialsProvider(), new SystemPropertyCerberusCredentialsProvider(), + new InstanceProfileVaultCredentialsProvider(urlResolver), new InstanceRoleVaultCredentialsProvider(urlResolver)); } } diff --git a/src/main/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProvider.java b/src/main/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProvider.java new file mode 100644 index 0000000..e11dcb0 --- /dev/null +++ b/src/main/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProvider.java @@ -0,0 +1,72 @@ +package com.nike.cerberus.client.auth.aws; + +import com.amazonaws.regions.Region; +import com.amazonaws.regions.Regions; +import com.amazonaws.util.EC2MetadataUtils; +import com.nike.vault.client.UrlResolver; +import com.nike.vault.client.VaultClientException; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * This Credentials provider will look up the assigned InstanceProfileArn for this machine and attempt + * To automatically retrieve a Vault token from CMS's iam-auth endpoint that takes region, acct id, role name. + */ +public class InstanceProfileVaultCredentialsProvider extends BaseAwsCredentialsProvider { + + /** + * Constructor to setup credentials provider using the specified + * implementation of {@link UrlResolver} + * + * @param urlResolver Resolver for resolving the Cerberus URL + */ + public InstanceProfileVaultCredentialsProvider(UrlResolver urlResolver) { + super(urlResolver); + } + + @Override + protected void authenticate() { + EC2MetadataUtils.IAMInfo iamInfo = getIamInfo(); + IamAuthInfo iamAuthInfo = getIamAuthInfo(iamInfo.instanceProfileArn); + + try { + getAndSetToken(iamAuthInfo.accountId, iamAuthInfo.roleName, iamAuthInfo.region); + } catch (Exception e) { + throw new VaultClientException(String.format("Failed to authenticate with Cerberus's iam auth endpoint " + + "using the following auth info, acct id: %s, roleName: %s, region: %s", + iamAuthInfo.accountId, iamAuthInfo.roleName, iamAuthInfo.region), e); + } + } + + protected IamAuthInfo getIamAuthInfo(String instanceProfileArn) { + if (instanceProfileArn == null) { + throw new VaultClientException("instanceProfileArn provided was null rather than valid arn"); + } + + IamAuthInfo info = new IamAuthInfo(); + String pattern = "arn:aws:iam::(.*?):instance-profile/(.*)"; + Matcher matcher = Pattern.compile(pattern).matcher(instanceProfileArn); + boolean found = matcher.find(); + if (! found) { + throw new VaultClientException(String.format( + "Failed to find account id and role / instance profile name from ARN: %s using pattern %s", + instanceProfileArn, pattern)); + } + + info.accountId = matcher.group(1); + info.roleName = matcher.group(2); + + return info; + } + + protected static class IamAuthInfo { + String accountId; + String roleName; + Region region = Regions.getCurrentRegion(); + } + + protected EC2MetadataUtils.IAMInfo getIamInfo() { + return EC2MetadataUtils.getIAMInstanceProfileInfo(); + } +} diff --git a/src/main/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProvider.java b/src/main/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProvider.java index 63ee412..7b8cc6b 100644 --- a/src/main/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProvider.java +++ b/src/main/java/com/nike/cerberus/client/auth/aws/StaticIamRoleVaultCredentialsProvider.java @@ -80,20 +80,20 @@ private StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver) { } private String getAccountIdFromArn(String arn) { - Matcher m = Pattern.compile("arn:aws:iam::(.*?):role.*").matcher(arn); - boolean found = m.find(); + Matcher matcher = Pattern.compile("arn:aws:iam::(.*?):role.*").matcher(arn); + boolean found = matcher.find(); if (found) { - return m.group(1); + return matcher.group(1); } throw new IllegalArgumentException("Invalid IAM role ARN supplied, expected arn:aws:iam::%s:role/%s"); } private String getRoleNameFromArn(String arn) { - Matcher m = Pattern.compile("arn:aws:iam::.*?:role/(.*)").matcher(arn); - boolean found = m.find(); + Matcher matcher = Pattern.compile("arn:aws:iam::.*?:role/(.*)").matcher(arn); + boolean found = matcher.find(); if (found) { - return m.group(1); + return matcher.group(1); } throw new IllegalArgumentException("Invalid IAM role ARN supplied, expected arn:aws:iam::%s:role/%s"); diff --git a/src/test/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProviderTest.java b/src/test/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProviderTest.java new file mode 100644 index 0000000..de7fa28 --- /dev/null +++ b/src/test/java/com/nike/cerberus/client/auth/aws/InstanceProfileVaultCredentialsProviderTest.java @@ -0,0 +1,60 @@ +package com.nike.cerberus.client.auth.aws; + +import com.amazonaws.regions.Region; +import com.amazonaws.util.EC2MetadataUtils; +import com.nike.vault.client.StaticVaultUrlResolver; +import com.nike.vault.client.VaultClientException; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +public class InstanceProfileVaultCredentialsProviderTest { + + InstanceProfileVaultCredentialsProvider provider; + + @Before + public void before() { + provider = new InstanceProfileVaultCredentialsProvider(new StaticVaultUrlResolver("foo")); + } + + @Test + public void test_that_valid_arn_gets_parsed() { + InstanceProfileVaultCredentialsProvider.IamAuthInfo info = provider.getIamAuthInfo("arn:aws:iam::1234:inst" + + "ance-profile/base/prod-base-sdfgsdfg-be5c-47ff-b82f-sdfgsdfgsfdg-CmsInstanceProfile-sdfgsdfgsdfg"); + + assertEquals("1234", info.accountId); + assertEquals("base/prod-base-sdfgsdfg-be5c-47ff-b82f-sdfgsdfgsfdg-CmsInstanceProfile-sdfgsdfgsdfg", + info.roleName); + } + + @Test(expected = VaultClientException.class) + public void test_that_invalid_arn_fails() { + provider.getIamAuthInfo(""); + } + + @Test(expected = VaultClientException.class) + public void test_that_null_arn_fails() { + provider.getIamAuthInfo(null); + } + + @Test(expected = VaultClientException.class) + public void test_that_authenticate_catches_exceptions_and_throws_vault_exception() { + InstanceProfileVaultCredentialsProvider providerSpy = spy(provider); + + doThrow(new RuntimeException("Foo")).when(providerSpy).getAndSetToken(anyString(), anyString(), any(Region.class)); + doReturn(new InstanceProfileVaultCredentialsProvider.IamAuthInfo()).when(providerSpy).getIamAuthInfo(anyString()); + EC2MetadataUtils.IAMInfo iamInfo = new EC2MetadataUtils.IAMInfo(); + iamInfo.instanceProfileArn = "foo"; + doReturn(iamInfo).when(providerSpy).getIamInfo(); + + providerSpy.authenticate(); + } +}