Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Merge pull request #15 from Nike-Inc/feature/use_v2_iam_auth_api
Browse files Browse the repository at this point in the history
Use new V2 API for IAM authentication
  • Loading branch information
sdford authored Apr 27, 2017
2 parents 07d30db + 9e22b79 commit 7d47139
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 160 deletions.
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
# limitations under the License.
#

version=1.5.0
version=2.0.0
groupId=com.nike
artifactId=cerberus-client
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import static com.nike.cerberus.client.auth.aws.StaticIamRoleVaultCredentialsProvider.IAM_ROLE_ARN_FORMAT;

/**
* {@link VaultCredentialsProvider} implementation that uses some AWS
* credentials provider to authenticate with Cerberus and decrypt the auth
Expand Down Expand Up @@ -141,28 +143,28 @@ public VaultCredentials getCredentials() {
*
* @param accountId
* AWS account ID used to auth with cerberus
* @param iamRole
* @param iamRoleName
* IAM role name used to auth with cerberus
*/
protected void getAndSetToken(final String accountId, final String iamRole) {
getAndSetToken(accountId, iamRole, Regions.getCurrentRegion());
protected void getAndSetToken(final String accountId, final String iamRoleName) {
final String iamRoleArn = String.format(IAM_ROLE_ARN_FORMAT, accountId, iamRoleName);

getAndSetToken(iamRoleArn, Regions.getCurrentRegion());
}

/**
* Authenticates with Cerberus and decrypts and sets the token and expiration details.
*
* @param accountId
* AWS account ID used to auth with cerberus
* @param iamRole
* IAM role name used to auth with cerberus
* @param iamPrincipalArn
* AWS IAM principal ARN used to auth with cerberus
* @param region
* AWS Region used in auth with cerberus
*/
protected void getAndSetToken(final String accountId, final String iamRole, final Region region) {
protected void getAndSetToken(final String iamPrincipalArn, final Region region) {
final AWSKMSClient kmsClient = new AWSKMSClient();
kmsClient.setRegion(region);

final String encryptedAuthData = getEncryptedAuthData(accountId, iamRole, region);
final String encryptedAuthData = getEncryptedAuthData(iamPrincipalArn, region);
final VaultAuthResponse decryptedToken = decryptToken(kmsClient, encryptedAuthData);
final DateTime expires = DateTime.now(DateTimeZone.UTC);
expires.plusSeconds(decryptedToken.getLeaseDuration() - paddingTimeInSeconds);
Expand All @@ -173,16 +175,13 @@ protected void getAndSetToken(final String accountId, final String iamRole, fina

/**
* Retrieves the encrypted auth response from Cerberus.
*
* @param accountId
* AWS account ID used in the row key
* @param roleName
* IAM role name used in the row key
* @param iamPrincipalArn
* IAM principal ARN used in the row key
* @param region
* Current region of the running function or instance
* Current region of the running function or instance
* @return Base64 and encrypted token
*/
protected String getEncryptedAuthData(final String accountId, final String roleName, Region region) {
protected String getEncryptedAuthData(final String iamPrincipalArn, Region region) {
final String url = urlResolver.resolve();

if (StringUtils.isBlank(url)) {
Expand All @@ -191,14 +190,14 @@ protected String getEncryptedAuthData(final String accountId, final String roleN

final OkHttpClient httpClient = new OkHttpClient();

LOGGER.info(String.format("Attempting to authenticate with AWS account id [%s] and role [%s] against [%s]",
accountId, roleName, url));
LOGGER.info(String.format("Attempting to authenticate with AWS IAM principal ARN [%s] against [%s]",
iamPrincipalArn, url));

try {
Request.Builder requestBuilder = new Request.Builder().url(url + "/v1/auth/iam-role")
Request.Builder requestBuilder = new Request.Builder().url(url + "/v2/auth/iam-principal")
.addHeader(HttpHeader.ACCEPT, DEFAULT_MEDIA_TYPE.toString())
.addHeader(HttpHeader.CONTENT_TYPE, DEFAULT_MEDIA_TYPE.toString())
.method(HttpMethod.POST, buildCredentialsRequestBody(accountId, roleName, region));
.method(HttpMethod.POST, buildCredentialsRequestBody(iamPrincipalArn, region));

Response response = httpClient.newCall(requestBuilder.build()).execute();

Expand Down Expand Up @@ -249,12 +248,11 @@ protected VaultAuthResponse decryptToken(AWSKMS kmsClient, String encryptedToken
return gson.fromJson(decryptedAuthData, VaultAuthResponse.class);
}

private RequestBody buildCredentialsRequestBody(final String accountId, final String roleName, Region region) {
private RequestBody buildCredentialsRequestBody(final String iamPrincipalArn, Region region) {
final String regionName = region == null ? Regions.getCurrentRegion().getName() : region.getName();

final Map<String, String> credentials = new HashMap<>();
credentials.put("account_id", accountId);
credentials.put("role_name", roleName);
credentials.put("iam_principal_arn", iamPrincipalArn);
credentials.put("region", regionName);

return RequestBody.create(DEFAULT_MEDIA_TYPE, gson.toJson(credentials));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import com.nike.vault.client.UrlResolver;
import com.nike.vault.client.VaultClientException;

import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* This Credentials provider will look up the assigned InstanceProfileArn for this machine and attempt
* To automatically retrieve a Vault token from CMS's iam-auth endpoint that takes region, acct id, role name.
Expand All @@ -28,42 +25,16 @@ public InstanceProfileVaultCredentialsProvider(UrlResolver urlResolver) {
@Override
protected void authenticate() {
EC2MetadataUtils.IAMInfo iamInfo = getIamInfo();
IamAuthInfo iamAuthInfo = getIamAuthInfo(iamInfo.instanceProfileArn);
String instanceProfileArn = iamInfo.instanceProfileArn;
Region region = Regions.getCurrentRegion();

try {
getAndSetToken(iamAuthInfo.accountId, iamAuthInfo.roleName, iamAuthInfo.region);
getAndSetToken(instanceProfileArn, region);
} catch (Exception e) {
throw new VaultClientException(String.format("Failed to authenticate with Cerberus's iam auth endpoint " +
"using the following auth info, acct id: %s, roleName: %s, region: %s",
iamAuthInfo.accountId, iamAuthInfo.roleName, iamAuthInfo.region), e);
}
}

protected IamAuthInfo getIamAuthInfo(String instanceProfileArn) {
if (instanceProfileArn == null) {
throw new VaultClientException("instanceProfileArn provided was null rather than valid arn");
}

IamAuthInfo info = new IamAuthInfo();
String pattern = "arn:aws:iam::(.*?):instance-profile/(.*)";
Matcher matcher = Pattern.compile(pattern).matcher(instanceProfileArn);
boolean found = matcher.find();
if (! found) {
throw new VaultClientException(String.format(
"Failed to find account id and role / instance profile name from ARN: %s using pattern %s",
instanceProfileArn, pattern));
"using the following auth info, iamPrincipalArn: %s, region: %s",
instanceProfileArn, region), e);
}

info.accountId = matcher.group(1);
info.roleName = matcher.group(2);

return info;
}

protected static class IamAuthInfo {
String accountId;
String roleName;
Region region = Regions.getCurrentRegion();
}

protected EC2MetadataUtils.IAMInfo getIamInfo() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,15 @@ protected void authenticate() {
throw new IllegalStateException("Lambda function assigned role is not a valid IAM role ARN.");
}

final String accountId = roleArnMatcher.group("accountId");
final String iamRoleArn = roleArnMatcher.group("roleName");

try {
getAndSetToken(accountId, iamRoleArn, currentRegion);
getAndSetToken(roleArn, currentRegion);
return;
} catch (AmazonClientException ace) {
LOGGER.warn("Unexpected error communicating with AWS services.", ace);
} catch (JsonSyntaxException jse) {
LOGGER.error("The decrypted auth response was not in the expected format!", jse);
} catch (VaultClientException sce) {
LOGGER.warn("Unable to acquire Vault token for IAM role: " + iamRoleArn, sce);
LOGGER.warn("Unable to acquire Vault token for IAM role: " + roleArn, sce);
}

throw new VaultClientException("Unable to acquire token with Lambda instance role.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,94 +13,71 @@
*/
public class StaticIamRoleVaultCredentialsProvider extends BaseAwsCredentialsProvider {

protected String accountId;
protected String roleName;
public static final String IAM_ROLE_ARN_FORMAT = "arn:aws:iam::%s:role/%s";
protected String iamPrincipalArn;
protected Region region;

public StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver, String accountId, String roleName, String region) {
this(urlResolver);
this.accountId = accountId;
this.roleName = roleName;
this.iamPrincipalArn = generateIamRoleArn(accountId, roleName);
this.region = Region.getRegion(Regions.fromName(region));
}

public StaticIamRoleVaultCredentialsProvider(String vaultUrl, String accountId, String roleName, String region) {
this(new StaticVaultUrlResolver(vaultUrl));
this.accountId = accountId;
this.roleName = roleName;
this.iamPrincipalArn = generateIamRoleArn(accountId, roleName);
this.region = Region.getRegion(Regions.fromName(region));
}

public StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver, String accountId, String roleName, Region region) {
this(urlResolver);
this.accountId = accountId;
this.roleName = roleName;
this.iamPrincipalArn = generateIamRoleArn(accountId, roleName);
this.region = region;
}

public StaticIamRoleVaultCredentialsProvider(String vaultUrl, String accountId, String roleName, Region region) {
this(new StaticVaultUrlResolver(vaultUrl));
this.accountId = accountId;
this.roleName = roleName;
this.iamPrincipalArn = generateIamRoleArn(accountId, roleName);
this.region = region;
}

public StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver, String iamRoleArn, String region) {
this(urlResolver);
this.accountId = getAccountIdFromArn(iamRoleArn);
this.roleName = getRoleNameFromArn(iamRoleArn);
this.iamPrincipalArn = iamRoleArn;
this.region = Region.getRegion(Regions.fromName(region));
}


public StaticIamRoleVaultCredentialsProvider(String vaultUrl, String iamRoleArn, String region) {
this(new StaticVaultUrlResolver(vaultUrl));
this.accountId = getAccountIdFromArn(iamRoleArn);
this.roleName = getRoleNameFromArn(iamRoleArn);
this.iamPrincipalArn = iamRoleArn;
this.region = Region.getRegion(Regions.fromName(region));
}


public StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver, String iamRoleArn, Region region) {
this(urlResolver);
this.accountId = getAccountIdFromArn(iamRoleArn);
this.roleName = getRoleNameFromArn(iamRoleArn);
this.iamPrincipalArn = iamRoleArn;
this.region = region;
}

public StaticIamRoleVaultCredentialsProvider(String vaultUrl, String iamRoleArn, Region region) {
this(new StaticVaultUrlResolver(vaultUrl));
this.accountId = getAccountIdFromArn(iamRoleArn);
this.roleName = getRoleNameFromArn(iamRoleArn);
this.iamPrincipalArn = iamRoleArn;
this.region = region;
}

private StaticIamRoleVaultCredentialsProvider(UrlResolver urlResolver) {
super(urlResolver);
}

private String getAccountIdFromArn(String arn) {
Matcher matcher = Pattern.compile("arn:aws:iam::(.*?):role.*").matcher(arn);
boolean found = matcher.find();
if (found) {
return matcher.group(1);
}
private String generateIamRoleArn(String accountId, String roleName) {

throw new IllegalArgumentException("Invalid IAM role ARN supplied, expected arn:aws:iam::%s:role/%s");
}

private String getRoleNameFromArn(String arn) {
Matcher matcher = Pattern.compile("arn:aws:iam::.*?:role/(.*)").matcher(arn);
boolean found = matcher.find();
if (found) {
return matcher.group(1);
}

throw new IllegalArgumentException("Invalid IAM role ARN supplied, expected arn:aws:iam::%s:role/%s");
return String.format(IAM_ROLE_ARN_FORMAT, accountId, roleName);
}

@Override
protected void authenticate() {
getAndSetToken(accountId, roleName, region);
getAndSetToken(iamPrincipalArn, region);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@

public class BaseAwsCredentialsProviderTest extends BaseCredentialsProviderTest{
public static final Region REGION = RegionUtils.getRegion("us-west-2");
public static final String ACCOUNT_ID = "123456789012";
public static final String CERBERUS_TEST_ROLE = "cerberus-test-role";
public static final String CERBERUS_TEST_ARN = "arn:aws:iam::123456789012:role/cerberus-test-role";
public static final String ERROR_RESPONSE = "Error calling vault";

protected static final String MISSING_AUTH_DATA = "{}";
Expand Down Expand Up @@ -56,7 +55,7 @@ public void tearDown() throws Exception {
public void getEncryptedAuthData_blank_url_throws_exception() throws Exception {
when(urlResolver.resolve()).thenReturn("");

provider.getEncryptedAuthData(ACCOUNT_ID, CERBERUS_TEST_ROLE, REGION);
provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);
}

@Test(expected = VaultClientException.class)
Expand All @@ -71,7 +70,7 @@ public void getEncryptedAuthData_throws_exception_on_bad_response_code() throws
System.setProperty(DefaultCerberusUrlResolver.CERBERUS_ADDR_SYS_PROPERTY, vaultUrl);
mockWebServer.enqueue(new MockResponse().setResponseCode(400).setBody(ERROR_RESPONSE));

provider.getEncryptedAuthData(ACCOUNT_ID, CERBERUS_TEST_ROLE, REGION);
provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);
}

@Test(expected = VaultClientException.class)
Expand All @@ -81,7 +80,7 @@ public void getEncryptedAuthData_throws_exception_on_missing_auth_data() throws
System.setProperty(DefaultCerberusUrlResolver.CERBERUS_ADDR_SYS_PROPERTY, vaultUrl);
mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(MISSING_AUTH_DATA));

provider.getEncryptedAuthData(ACCOUNT_ID, CERBERUS_TEST_ROLE, REGION);
provider.getEncryptedAuthData(CERBERUS_TEST_ARN, REGION);
}

class TestAwsCredentialsProvider extends BaseAwsCredentialsProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,12 @@ public void before() {
provider = new InstanceProfileVaultCredentialsProvider(new StaticVaultUrlResolver("foo"));
}

@Test
public void test_that_valid_arn_gets_parsed() {
InstanceProfileVaultCredentialsProvider.IamAuthInfo info = provider.getIamAuthInfo("arn:aws:iam::1234:inst" +
"ance-profile/base/prod-base-sdfgsdfg-be5c-47ff-b82f-sdfgsdfgsfdg-CmsInstanceProfile-sdfgsdfgsdfg");

assertEquals("1234", info.accountId);
assertEquals("base/prod-base-sdfgsdfg-be5c-47ff-b82f-sdfgsdfgsfdg-CmsInstanceProfile-sdfgsdfgsdfg",
info.roleName);
}

@Test(expected = VaultClientException.class)
public void test_that_invalid_arn_fails() {
provider.getIamAuthInfo("");
}

@Test(expected = VaultClientException.class)
public void test_that_null_arn_fails() {
provider.getIamAuthInfo(null);
}

@Test(expected = VaultClientException.class)
public void test_that_authenticate_catches_exceptions_and_throws_vault_exception() {
InstanceProfileVaultCredentialsProvider providerSpy = spy(provider);

doThrow(new RuntimeException("Foo")).when(providerSpy).getAndSetToken(anyString(), anyString(), any(Region.class));
doReturn(new InstanceProfileVaultCredentialsProvider.IamAuthInfo()).when(providerSpy).getIamAuthInfo(anyString());
doThrow(new RuntimeException("Foo")).when(providerSpy).getAndSetToken(anyString(), any(Region.class));

EC2MetadataUtils.IAMInfo iamInfo = new EC2MetadataUtils.IAMInfo();
iamInfo.instanceProfileArn = "foo";
doReturn(iamInfo).when(providerSpy).getIamInfo();
Expand Down
Loading

0 comments on commit 7d47139

Please sign in to comment.