diff --git a/README.md b/README.md index 7787761..19722e2 100644 --- a/README.md +++ b/README.md @@ -11,11 +11,13 @@ This library acts as a wrapper around the Nike developed Vault client by configu ## Quickstart +### Default Client + ``` java final VaultClient vaultClient = DefaultCerberusClientFactory.getClient(); ``` -### Default URL Assumptions +#### Default URL Assumptions The example above uses the `DefaultCerberusUrlResolver` to resolve the URL for Vault. @@ -27,7 +29,7 @@ or the JVM system property, `cerberus.addr`, must be set: cerberus.addr=https://cerberus -### Default Credentials Provider Assumptions +#### Default Credentials Provider Assumptions Again, for the example above, the `DefaultCerberusCredentialsProviderChain` is used to resolve the token needed to interact with Vault. @@ -39,13 +41,54 @@ or the JVM system property, `vault.token`, must be set: cerberus.token=TOKEN -or the IAM role authentication flow: +or the EC2 IAM role authentication flow: If the client library is running on an EC2 instance, it will attempt to use the instance's assigned IAM role to authenticate with Cerberus and obtain a token. The IAM role must be configured for access to Cerberus before this will work. +The following policy statement must also be assigned to the IAM role, so that the client can automatically decrypt the auth token from the Cerberus IAM auth endpoint: + +``` json + { + "Sid": "allow-kms-decrypt", + "Effect": "Allow", + "Action": [ + "kms:Decrypt" + ], + "Resource": [ + "*" + ] + } +``` + +### Client that can authenticate from Lambdas + +#### Prerequisites + +The IAM role assigned to the Lambda function must contain the following policy statement in addition to the above KMS decrypt policy, this is so the Lambda can look up its metadata to automatically authenticate with the Cerberus IAM auth endpoint: + +``` json + { + "Sid": "allow-get-function-config", + "Effect": "Allow", + "Action": [ + "lambda:GetFunctionConfiguration" + ], + "Resource": [ + "*" + ] + } +``` + +#### Configure the Client + +``` java + final String invokedFunctionArn = context.getInvokedFunctionArn() + final VaultClient vaultClient = DefaultCerberusClientFactory.getClientForLambda(invokedFunctionArn); +``` + ## Further Details Cerberus client is a small project. It only has a few classes and they are all fully documented. For further details please see the source code, including javadocs and unit tests. diff --git a/gradle.properties b/gradle.properties index 77270e6..273f6f8 100644 --- a/gradle.properties +++ b/gradle.properties @@ -14,6 +14,6 @@ # limitations under the License. # -version=1.0.0 +version=1.1.0 groupId=com.nike artifactId=cerberus-client \ No newline at end of file diff --git a/gradle/dependencies.gradle b/gradle/dependencies.gradle index b5b67f8..a3c3219 100644 --- a/gradle/dependencies.gradle +++ b/gradle/dependencies.gradle @@ -39,6 +39,7 @@ dependencies { compile "org.slf4j:slf4j-api:1.7.14" compile "com.amazonaws:aws-java-sdk-core:1.10.50" compile "com.amazonaws:aws-java-sdk-kms:1.10.50" + compile "com.amazonaws:aws-java-sdk-lambda:1.10.50" testCompile "junit:junit:4.12" testCompile ("org.mockito:mockito-core:1.10.19") { diff --git a/src/main/java/com/nike/cerberus/client/DefaultCerberusClientFactory.java b/src/main/java/com/nike/cerberus/client/DefaultCerberusClientFactory.java index ffe512f..447814d 100644 --- a/src/main/java/com/nike/cerberus/client/DefaultCerberusClientFactory.java +++ b/src/main/java/com/nike/cerberus/client/DefaultCerberusClientFactory.java @@ -17,8 +17,12 @@ package com.nike.cerberus.client; import com.nike.cerberus.client.auth.DefaultCerberusCredentialsProviderChain; +import com.nike.cerberus.client.auth.EnvironmentCerberusCredentialsProvider; +import com.nike.cerberus.client.auth.SystemPropertyCerberusCredentialsProvider; +import com.nike.cerberus.client.auth.aws.LambdaRoleVaultCredentialsProvider; import com.nike.vault.client.VaultClient; import com.nike.vault.client.VaultClientFactory; +import com.nike.vault.client.auth.VaultCredentialsProviderChain; /** * Client factory for creating a Vault client with a URL resolver and credentials provider specific to Cerberus. @@ -35,4 +39,21 @@ public static VaultClient getClient() { return VaultClientFactory.getClient(new DefaultCerberusUrlResolver(), new DefaultCerberusCredentialsProviderChain()); } + + /** + * Creates a new {@link VaultClient} with the {@link DefaultCerberusUrlResolver} for URL resolving + * and a credentials provider chain that includes the {@link LambdaRoleVaultCredentialsProvider} for obtaining + * credentials. + * + * @param invokedFunctionArn The ARN for the AWS Lambda function being invoked. + * @return Vault client + */ + public static VaultClient getClientForLambda(final String invokedFunctionArn) { + final DefaultCerberusUrlResolver urlResolver = new DefaultCerberusUrlResolver(); + return VaultClientFactory.getClient(urlResolver, + new VaultCredentialsProviderChain( + new EnvironmentCerberusCredentialsProvider(), + new SystemPropertyCerberusCredentialsProvider(), + new LambdaRoleVaultCredentialsProvider(urlResolver, invokedFunctionArn))); + } } diff --git a/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java b/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java index 2411c99..7cab28b 100644 --- a/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java +++ b/src/main/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProvider.java @@ -16,12 +16,13 @@ package com.nike.cerberus.client.auth.aws; +import com.amazonaws.regions.Region; import com.amazonaws.regions.Regions; import com.amazonaws.services.kms.AWSKMS; +import com.amazonaws.services.kms.AWSKMSClient; import com.amazonaws.services.kms.model.DecryptRequest; import com.amazonaws.services.kms.model.DecryptResult; import com.amazonaws.util.Base64; -import com.amazonaws.util.EC2MetadataUtils; import com.google.gson.FieldNamingPolicy; import com.google.gson.Gson; import com.google.gson.GsonBuilder; @@ -43,6 +44,7 @@ import okhttp3.Response; import org.apache.commons.lang3.StringUtils; import org.joda.time.DateTime; +import org.joda.time.DateTimeZone; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -56,8 +58,6 @@ import java.util.Map; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.regex.Matcher; -import java.util.regex.Pattern; /** * {@link VaultCredentialsProvider} implementation that uses some AWS @@ -72,8 +72,6 @@ public abstract class BaseAwsCredentialsProvider implements VaultCredentialsProv private static final Logger LOGGER = LoggerFactory.getLogger(BaseAwsCredentialsProvider.class); - private final Pattern iamArnPattern = Pattern.compile("(arn\\:aws\\:iam\\:\\:)(?[0-9].*)(\\:.*)"); - private final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock(); private final Lock readLock = readWriteLock.readLock(); @@ -139,29 +137,38 @@ public VaultCredentials getCredentials() { abstract protected void authenticate(); /** - * Parses and returns the AWS account ID from the instance profile ARN. + * Authenticates with Cerberus and decrypts and sets the token and expiration details. * - * @return AWS account ID + * @param accountId + * AWS account ID used to auth with cerberus + * @param iamRole + * IAM role name used to auth with cerberus */ - protected String lookupAccountId() { - final EC2MetadataUtils.IAMInfo iamInfo = EC2MetadataUtils.getIAMInstanceProfileInfo(); - - if (iamInfo == null) { - final String errorMessage = "No IAM Instance Profile assigned to running instance."; - LOGGER.error(errorMessage); - throw new VaultClientException(errorMessage); - } + protected void getAndSetToken(final String accountId, final String iamRole) { + getAndSetToken(accountId, iamRole, Regions.getCurrentRegion()); + } - final Matcher matcher = iamArnPattern.matcher(iamInfo.instanceProfileArn); + /** + * Authenticates with Cerberus and decrypts and sets the token and expiration details. + * + * @param accountId + * AWS account ID used to auth with cerberus + * @param iamRole + * IAM role name used to auth with cerberus + * @param region + * AWS Region used in auth with cerberus + */ + protected void getAndSetToken(final String accountId, final String iamRole, final Region region) { + final AWSKMSClient kmsClient = new AWSKMSClient(); + kmsClient.setRegion(region); - if (matcher.matches()) { - final String accountId = matcher.group("accountId"); - if (StringUtils.isNotBlank(accountId)) { - return accountId; - } - } + final String encryptedAuthData = getEncryptedAuthData(accountId, iamRole, region); + final VaultAuthResponse decryptedToken = decryptToken(kmsClient, encryptedAuthData); + final DateTime expires = DateTime.now(DateTimeZone.UTC); + expires.plusSeconds(decryptedToken.getLeaseDuration() - paddingTimeInSeconds); - throw new VaultClientException("Unable to obtain AWS account ID from instance profile ARN."); + credentials = new TokenVaultCredentials(decryptedToken.getClientToken()); + expireDateTime = expires; } /** @@ -171,9 +178,11 @@ protected String lookupAccountId() { * AWS account ID used in the row key * @param roleName * IAM role name used in the row key + * @param region + * Current region of the running function or instance * @return Base64 and encrypted token */ - protected String getEncryptedAuthData(final String accountId, final String roleName) { + protected String getEncryptedAuthData(final String accountId, final String roleName, Region region) { final String url = urlResolver.resolve(); if (StringUtils.isBlank(url)) { @@ -189,7 +198,7 @@ protected String getEncryptedAuthData(final String accountId, final String roleN Request.Builder requestBuilder = new Request.Builder().url(url + "/v1/auth/iam-role") .addHeader(HttpHeader.ACCEPT, DEFAULT_MEDIA_TYPE.toString()) .addHeader(HttpHeader.CONTENT_TYPE, DEFAULT_MEDIA_TYPE.toString()) - .method(HttpMethod.POST, buildCredentialsRequestBody(accountId, roleName)); + .method(HttpMethod.POST, buildCredentialsRequestBody(accountId, roleName, region)); Response response = httpClient.newCall(requestBuilder.build()).execute(); @@ -240,11 +249,13 @@ protected VaultAuthResponse decryptToken(AWSKMS kmsClient, String encryptedToken return gson.fromJson(decryptedAuthData, VaultAuthResponse.class); } - private RequestBody buildCredentialsRequestBody(final String accountId, final String roleName) { + private RequestBody buildCredentialsRequestBody(final String accountId, final String roleName, Region region) { + final String regionName = region == null ? Regions.getCurrentRegion().getName() : region.getName(); + final Map credentials = new HashMap<>(); credentials.put("account_id", accountId); credentials.put("role_name", roleName); - credentials.put("region", Regions.getCurrentRegion().getName()); + credentials.put("region", regionName); return RequestBody.create(DEFAULT_MEDIA_TYPE, gson.toJson(credentials)); } diff --git a/src/main/java/com/nike/cerberus/client/auth/aws/InstanceRoleVaultCredentialsProvider.java b/src/main/java/com/nike/cerberus/client/auth/aws/InstanceRoleVaultCredentialsProvider.java index 080f02a..068e1fc 100644 --- a/src/main/java/com/nike/cerberus/client/auth/aws/InstanceRoleVaultCredentialsProvider.java +++ b/src/main/java/com/nike/cerberus/client/auth/aws/InstanceRoleVaultCredentialsProvider.java @@ -17,21 +17,18 @@ package com.nike.cerberus.client.auth.aws; import com.amazonaws.AmazonClientException; -import com.amazonaws.regions.Regions; -import com.amazonaws.services.kms.AWSKMSClient; import com.amazonaws.util.EC2MetadataUtils; import com.google.gson.JsonSyntaxException; import com.nike.vault.client.UrlResolver; import com.nike.vault.client.VaultClientException; -import com.nike.vault.client.auth.TokenVaultCredentials; import com.nike.vault.client.auth.VaultCredentialsProvider; -import com.nike.vault.client.model.VaultAuthResponse; -import org.joda.time.DateTime; -import org.joda.time.DateTimeZone; +import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; /** @@ -45,6 +42,8 @@ public class InstanceRoleVaultCredentialsProvider extends BaseAwsCredentialsProv private static final Logger LOGGER = LoggerFactory.getLogger(InstanceRoleVaultCredentialsProvider.class); + public static final Pattern IAM_ARN_PATTERN = Pattern.compile("(arn\\:aws\\:iam\\:\\:)(?[0-9].*)(\\:.*)"); + /** * Constructor to setup credentials provider using the specified * implementation of {@link UrlResolver} @@ -68,24 +67,13 @@ protected void authenticate() { final Set iamRoleSet = EC2MetadataUtils.getIAMSecurityCredentials().keySet(); final String accountId = lookupAccountId(); - final AWSKMSClient kmsClient = new AWSKMSClient(); - kmsClient.setRegion(Regions.getCurrentRegion()); - for (final String iamRole : iamRoleSet) { try { - final String encryptedAuthData = getEncryptedAuthData(accountId, iamRole); - final VaultAuthResponse decryptedToken = decryptToken(kmsClient, encryptedAuthData); - final DateTime expires = DateTime.now(DateTimeZone.UTC); - expires.plusSeconds(decryptedToken.getLeaseDuration() - paddingTimeInSeconds); - - credentials = new TokenVaultCredentials(decryptedToken.getClientToken()); - expireDateTime = expires; - + getAndSetToken(accountId, iamRole); return; } catch (VaultClientException sce) { LOGGER.warn("Unable to acquire Vault token for IAM role: " + iamRole, sce); } - } } catch (AmazonClientException ace) { LOGGER.warn("Unexpected error communicating with AWS services.", ace); @@ -95,4 +83,30 @@ protected void authenticate() { throw new VaultClientException("Unable to acquire token with EC2 instance role."); } + + /** + * Parses and returns the AWS account ID from the instance profile ARN. + * + * @return AWS account ID + */ + protected String lookupAccountId() { + final EC2MetadataUtils.IAMInfo iamInfo = EC2MetadataUtils.getIAMInstanceProfileInfo(); + + if (iamInfo == null) { + final String errorMessage = "No IAM Instance Profile assigned to running instance."; + LOGGER.error(errorMessage); + throw new VaultClientException(errorMessage); + } + + final Matcher matcher = IAM_ARN_PATTERN.matcher(iamInfo.instanceProfileArn); + + if (matcher.matches()) { + final String accountId = matcher.group("accountId"); + if (StringUtils.isNotBlank(accountId)) { + return accountId; + } + } + + throw new VaultClientException("Unable to obtain AWS account ID from instance profile ARN."); + } } diff --git a/src/main/java/com/nike/cerberus/client/auth/aws/LambdaRoleVaultCredentialsProvider.java b/src/main/java/com/nike/cerberus/client/auth/aws/LambdaRoleVaultCredentialsProvider.java new file mode 100644 index 0000000..a949a6c --- /dev/null +++ b/src/main/java/com/nike/cerberus/client/auth/aws/LambdaRoleVaultCredentialsProvider.java @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2016 Nike, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nike.cerberus.client.auth.aws; + +import com.amazonaws.AmazonClientException; +import com.amazonaws.regions.Region; +import com.amazonaws.regions.RegionUtils; +import com.amazonaws.services.lambda.AWSLambda; +import com.amazonaws.services.lambda.AWSLambdaClient; +import com.amazonaws.services.lambda.model.GetFunctionConfigurationRequest; +import com.amazonaws.services.lambda.model.GetFunctionConfigurationResult; +import com.google.gson.JsonSyntaxException; +import com.nike.vault.client.UrlResolver; +import com.nike.vault.client.VaultClientException; +import com.nike.vault.client.auth.VaultCredentialsProvider; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * {@link VaultCredentialsProvider} implementation that uses the assigned role + * to lambda function to authenticate with Cerberus and decrypt the auth + * response using KMS. If the assigned role has been granted the appropriate + * provisioned for usage of Vault, it will succeed and have a token that can be + * used to interact with Vault. + */ +public class LambdaRoleVaultCredentialsProvider extends BaseAwsCredentialsProvider { + + public static final Logger LOGGER = LoggerFactory.getLogger(LambdaRoleVaultCredentialsProvider.class); + + public static final Pattern LAMBDA_FUNCTION_ARN_PATTERN = + Pattern.compile("arn:aws:lambda:(?[a-zA-Z0-9-]+):(?[0-9]{12}):function:(?[a-zA-Z0-9-_]+)(:(?.*))?"); + + public static final Pattern IAM_ROLE_ARN_PATTERN = + Pattern.compile("arn:aws:iam::(?\\d{12}):role/?(?[a-zA-Z_0-9+=,.@\\-_/]+)"); + + private final String functionName; + private final String qualifier; + private final String region; + + /** + * Constructor to setup credentials provider using the specified + * implementation of {@link UrlResolver} + * + * @param urlResolver Resolver for resolving the Cerberus URL + * @param invokedFunctionArn The invoked lambda function's ARN + */ + public LambdaRoleVaultCredentialsProvider(final UrlResolver urlResolver, final String invokedFunctionArn) { + super(urlResolver); + final Matcher matcher = LAMBDA_FUNCTION_ARN_PATTERN.matcher(invokedFunctionArn); + + if (!matcher.matches()) { + throw new IllegalArgumentException("invokedFunctionArn not a properly formatted lambda function ARN."); + } + + this.functionName = matcher.group("functionName"); + this.qualifier = matcher.group("qualifier"); + this.region = matcher.group("awsRegion"); + } + + /** + * Looks up the assigned role for the running Lambda via the GetFunctionConfiguration API. Requests a token from + * Cerberus and attempts to decrypt it as that role. + */ + @Override + protected void authenticate() { + final Region currentRegion = RegionUtils.getRegion(this.region); + + final AWSLambda lambdaClient = new AWSLambdaClient(); + lambdaClient.setRegion(currentRegion); + + final GetFunctionConfigurationResult functionConfiguration = lambdaClient.getFunctionConfiguration( + new GetFunctionConfigurationRequest() + .withFunctionName(functionName) + .withQualifier(qualifier)); + + final String roleArn = functionConfiguration.getRole(); + + if (StringUtils.isBlank(roleArn)) { + throw new IllegalStateException("Lambda function has no assigned role, aborting Cerberus authentication."); + } + + final Matcher roleArnMatcher = IAM_ROLE_ARN_PATTERN.matcher(roleArn); + + if (!roleArnMatcher.matches()) { + throw new IllegalStateException("Lambda function assigned role is not a valid IAM role ARN."); + } + + final String accountId = roleArnMatcher.group("accountId"); + final String iamRoleArn = roleArnMatcher.group("roleName"); + + try { + getAndSetToken(accountId, iamRoleArn, currentRegion); + return; + } catch (AmazonClientException ace) { + LOGGER.warn("Unexpected error communicating with AWS services.", ace); + } catch (JsonSyntaxException jse) { + LOGGER.error("The decrypted auth response was not in the expected format!", jse); + } catch (VaultClientException sce) { + LOGGER.warn("Unable to acquire Vault token for IAM role: " + iamRoleArn, sce); + } + + throw new VaultClientException("Unable to acquire token with Lambda instance role."); + } + + +} diff --git a/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java b/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java new file mode 100644 index 0000000..a2584e3 --- /dev/null +++ b/src/test/java/com/nike/cerberus/client/auth/aws/BaseAwsCredentialsProviderTest.java @@ -0,0 +1,104 @@ +package com.nike.cerberus.client.auth.aws; + +import com.amazonaws.regions.Region; +import com.amazonaws.regions.RegionUtils; +import com.amazonaws.services.kms.AWSKMSClient; +import com.nike.cerberus.client.DefaultCerberusUrlResolver; +import com.nike.vault.client.UrlResolver; +import com.nike.vault.client.VaultClientException; +import com.nike.vault.client.VaultServerException; +import com.nike.vault.client.auth.VaultCredentials; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.reset; +import static org.powermock.api.mockito.PowerMockito.mock; +import static org.powermock.api.mockito.PowerMockito.when; + +public class BaseAwsCredentialsProviderTest extends BaseCredentialsProviderTest{ + public static final Region REGION = RegionUtils.getRegion("us-west-2"); + public static final String ACCOUNT_ID = "123456789012"; + public static final String CERBERUS_TEST_ROLE = "cerberus-test-role"; + public static final String ERROR_RESPONSE = "Error calling vault"; + + protected static final String MISSING_AUTH_DATA = "{}"; + + + private BaseAwsCredentialsProvider provider; + private UrlResolver urlResolver; + private String vaultUrl; + private MockWebServer mockWebServer; + + @Before + public void setUp() throws Exception { + urlResolver = mock(UrlResolver.class); + + provider = new TestAwsCredentialsProvider(urlResolver); + + mockWebServer = new MockWebServer(); + mockWebServer.start(); + + vaultUrl = "http://localhost:" + mockWebServer.getPort(); + } + + @After + public void tearDown() throws Exception { + reset(urlResolver); + } + + @Test(expected = VaultClientException.class) + public void getEncryptedAuthData_blank_url_throws_exception() throws Exception { + when(urlResolver.resolve()).thenReturn(""); + + provider.getEncryptedAuthData(ACCOUNT_ID, CERBERUS_TEST_ROLE, REGION); + } + + @Test(expected = VaultClientException.class) + public void decryptToken_throws_exception_when_non_encrypted_data_provided() { + provider.decryptToken(mock(AWSKMSClient.class), "non-encrypted-token"); + } + + @Test(expected = VaultServerException.class) + public void getEncryptedAuthData_throws_exception_on_bad_response_code() throws IOException { + when(urlResolver.resolve()).thenReturn(vaultUrl); + + System.setProperty(DefaultCerberusUrlResolver.CERBERUS_ADDR_SYS_PROPERTY, vaultUrl); + mockWebServer.enqueue(new MockResponse().setResponseCode(400).setBody(ERROR_RESPONSE)); + + provider.getEncryptedAuthData(ACCOUNT_ID, CERBERUS_TEST_ROLE, REGION); + } + + @Test(expected = VaultClientException.class) + public void getEncryptedAuthData_throws_exception_on_missing_auth_data() throws IOException { + when(urlResolver.resolve()).thenReturn(vaultUrl); + + System.setProperty(DefaultCerberusUrlResolver.CERBERUS_ADDR_SYS_PROPERTY, vaultUrl); + mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(MISSING_AUTH_DATA)); + + provider.getEncryptedAuthData(ACCOUNT_ID, CERBERUS_TEST_ROLE, REGION); + } + + class TestAwsCredentialsProvider extends BaseAwsCredentialsProvider { + /** + * Constructor to setup credentials provider using the specified + * implementation of {@link UrlResolver} + * + * @param urlResolver Resolver for resolving the Cerberus URL + */ + public TestAwsCredentialsProvider(UrlResolver urlResolver) { + super(urlResolver); + } + + @Override + protected void authenticate() { + + } + } + +} \ No newline at end of file diff --git a/src/test/java/com/nike/cerberus/client/auth/aws/BaseCredentialsProviderTest.java b/src/test/java/com/nike/cerberus/client/auth/aws/BaseCredentialsProviderTest.java new file mode 100644 index 0000000..ffba193 --- /dev/null +++ b/src/test/java/com/nike/cerberus/client/auth/aws/BaseCredentialsProviderTest.java @@ -0,0 +1,23 @@ +package com.nike.cerberus.client.auth.aws; + +import com.amazonaws.services.kms.AWSKMSClient; +import com.amazonaws.services.kms.model.DecryptRequest; +import com.amazonaws.services.kms.model.DecryptResult; + +import java.nio.ByteBuffer; + +import static org.mockito.Matchers.any; +import static org.powermock.api.mockito.PowerMockito.when; + +public class BaseCredentialsProviderTest { + protected static final String AUTH_RESPONSE = "{\"auth_data\":\"eyJjbGllbnRfdG9rZW4iOiI2NjMyY2I1Zi1mMTBjLTQ1NzItOTU0NS1lNTJmNDdmNmEzZmQiLCAibGVhc2VfZHVyYXRpb24iOiIzNjAwIn0=\"}"; + protected static final String BAD_AUTH_RESPONSE_JSON = "{,\"auth_data\":\"eyJjbGllbnRfdG9rZW4iOiI2NjMyY2I1Zi1mMTBjLTQ1NzItOTU0NS1lNTJmNDdmNmEzZmQiLCAibGVhc2VfZHVyYXRpb24iOiIzNjAwIn0=\"}"; + protected static final String DECODED_AUTH_DATA = "{\"client_token\":\"6632cb5f-f10c-4572-9545-e52f47f6a3fd\", \"lease_duration\":\"3600\"}"; + protected static final String AUTH_TOKEN = "6632cb5f-f10c-4572-9545-e52f47f6a3fd"; + + protected void mockDecrypt(AWSKMSClient kmsClient, final String toDecrypt) { + DecryptResult decryptResult = new DecryptResult(); + decryptResult.setPlaintext(ByteBuffer.wrap(toDecrypt.getBytes())); + when(kmsClient.decrypt(any(DecryptRequest.class))).thenReturn(decryptResult); + } +} diff --git a/src/test/java/com/nike/cerberus/client/auth/aws/InstanceRoleVaultCredentialsProviderTest.java b/src/test/java/com/nike/cerberus/client/auth/aws/InstanceRoleVaultCredentialsProviderTest.java index 4867653..a4bd8af 100644 --- a/src/test/java/com/nike/cerberus/client/auth/aws/InstanceRoleVaultCredentialsProviderTest.java +++ b/src/test/java/com/nike/cerberus/client/auth/aws/InstanceRoleVaultCredentialsProviderTest.java @@ -19,8 +19,6 @@ import com.amazonaws.AmazonClientException; import com.amazonaws.regions.Regions; import com.amazonaws.services.kms.AWSKMSClient; -import com.amazonaws.services.kms.model.DecryptRequest; -import com.amazonaws.services.kms.model.DecryptResult; import com.amazonaws.util.EC2MetadataUtils; import com.nike.cerberus.client.DefaultCerberusUrlResolver; import com.nike.vault.client.UrlResolver; @@ -36,7 +34,6 @@ import org.powermock.modules.junit4.PowerMockRunner; import java.io.IOException; -import java.nio.ByteBuffer; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -55,13 +52,7 @@ @PrepareForTest({AWSKMSClient.class, EC2MetadataUtils.class, InstanceRoleVaultCredentialsProvider.class}) @PowerMockIgnore({"javax.management.*","javax.net.*"}) -public class InstanceRoleVaultCredentialsProviderTest { - - private static final String AUTH_RESPONSE = "{\"auth_data\":\"eyJjbGllbnRfdG9rZW4iOiI2NjMyY2I1Zi1mMTBjLTQ1NzItOTU0NS1lNTJmNDdmNmEzZmQiLCAibGVhc2VfZHVyYXRpb24iOiIzNjAwIn0=\"}"; - - private static final String DECODED_AUTH_DATA = "{\"client_token\":\"6632cb5f-f10c-4572-9545-e52f47f6a3fd\", \"lease_duration\":\"3600\"}"; - - private static final String AUTH_TOKEN = "6632cb5f-f10c-4572-9545-e52f47f6a3fd"; +public class InstanceRoleVaultCredentialsProviderTest extends BaseCredentialsProviderTest { private static final String GOOD_INSTANCE_PROFILE_ARN = "arn:aws:iam::107274433934:instance-profile/rawr"; @@ -93,7 +84,7 @@ public void getCredentials_returns_valid_credentials() throws IOException { mockGetIamSecurityCredentials(DEFAULT_ROLE); mockGetIamInstanceProfileInfo(GOOD_INSTANCE_PROFILE_ARN); - mockDecrypt(DECODED_AUTH_DATA); + mockDecrypt(kmsClient, DECODED_AUTH_DATA); when(urlResolver.resolve()).thenReturn(vaultUrl); System.setProperty(DefaultCerberusUrlResolver.CERBERUS_ADDR_SYS_PROPERTY, vaultUrl); @@ -151,9 +142,4 @@ private void mockGetIamInstanceProfileInfo(final String instanceProfileArn) { when(EC2MetadataUtils.getIAMInstanceProfileInfo()).thenReturn(iamInfo); } - private void mockDecrypt(final String toDecrypt) { - DecryptResult decryptResult = new DecryptResult(); - decryptResult.setPlaintext(ByteBuffer.wrap(toDecrypt.getBytes())); - when(kmsClient.decrypt(any(DecryptRequest.class))).thenReturn(decryptResult); - } } \ No newline at end of file diff --git a/src/test/java/com/nike/cerberus/client/auth/aws/LambdaRoleVaultCredentialsProviderTest.java b/src/test/java/com/nike/cerberus/client/auth/aws/LambdaRoleVaultCredentialsProviderTest.java new file mode 100644 index 0000000..667fcbb --- /dev/null +++ b/src/test/java/com/nike/cerberus/client/auth/aws/LambdaRoleVaultCredentialsProviderTest.java @@ -0,0 +1,138 @@ +package com.nike.cerberus.client.auth.aws; + +import com.amazonaws.regions.RegionUtils; +import com.amazonaws.regions.Regions; +import com.amazonaws.services.kms.AWSKMSClient; +import com.amazonaws.services.lambda.AWSLambdaClient; +import com.amazonaws.services.lambda.model.GetFunctionConfigurationRequest; +import com.amazonaws.services.lambda.model.GetFunctionConfigurationResult; +import com.nike.cerberus.client.DefaultCerberusUrlResolver; +import com.nike.vault.client.UrlResolver; +import com.nike.vault.client.VaultClientException; +import com.nike.vault.client.auth.VaultCredentials; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.mockito.PowerMockito; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.*; +import static org.powermock.api.mockito.PowerMockito.mock; +import static org.powermock.api.mockito.PowerMockito.*; +import static org.powermock.api.mockito.PowerMockito.when; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({AWSKMSClient.class, Regions.class, AWSLambdaClient.class, LambdaRoleVaultCredentialsProvider.class}) +@PowerMockIgnore({"javax.management.*", "javax.net.*"}) +public class LambdaRoleVaultCredentialsProviderTest extends BaseCredentialsProviderTest { + private static final String VALID_LAMBDA_ARN = "arn:aws:lambda:us-west-2:123456789012:function:lambda-test:1.1.0"; + private static final String VALID_LAMBDA_ARN_NO_QUALIFIER = "arn:aws:lambda:us-west-2:012345678912:function:lambda-test"; + private static final String VALID_IAM_ARN = "arn:aws:iam::123456789012:role/cerberus-role"; + private static final String INVALID_ARN = "invalid-arn"; + + private AWSKMSClient kmsClient; + private UrlResolver urlResolver; + private AWSLambdaClient lambdaClient; + private MockWebServer mockWebServer; + private String vaultUrl; + + @Before + public void setup() throws Exception { + kmsClient = mock(AWSKMSClient.class); + urlResolver = mock(UrlResolver.class); + lambdaClient = mock(AWSLambdaClient.class); + + mockWebServer = new MockWebServer(); + mockWebServer.start(); + vaultUrl = "http://localhost:" + mockWebServer.getPort(); + + when(urlResolver.resolve()).thenReturn(vaultUrl); + + + mockStatic(Regions.class); + + when(Regions.getCurrentRegion()).thenReturn(RegionUtils.getRegion("us-west-2")); + whenNew(AWSLambdaClient.class).withNoArguments().thenReturn(lambdaClient); + whenNew(AWSKMSClient.class).withAnyArguments().thenReturn(kmsClient); + } + + @Test(expected = IllegalArgumentException.class) + public void provider_creation_fails_on_invalid_arn() { + LambdaRoleVaultCredentialsProvider provider = new LambdaRoleVaultCredentialsProvider(urlResolver, "invalid-lambda-arn"); + } + + @Test + public void valid_arn_and_no_qualifier_matched_properly_on_provider_creation() { + LambdaRoleVaultCredentialsProvider provider = new LambdaRoleVaultCredentialsProvider(urlResolver, VALID_LAMBDA_ARN_NO_QUALIFIER); + } + + @Test + public void getCredentials_returns_valid_creds() throws Exception { + final LambdaRoleVaultCredentialsProvider provider = PowerMockito.spy(new LambdaRoleVaultCredentialsProvider(urlResolver, VALID_LAMBDA_ARN)); + final GetFunctionConfigurationRequest request = new GetFunctionConfigurationRequest().withFunctionName("lambda-test").withQualifier("1.1.0"); + + when(urlResolver.resolve()).thenReturn(vaultUrl); + + System.setProperty(DefaultCerberusUrlResolver.CERBERUS_ADDR_SYS_PROPERTY, vaultUrl); + mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTH_RESPONSE)); + + mockDecrypt(kmsClient, DECODED_AUTH_DATA); + + when(lambdaClient.getFunctionConfiguration(request)).thenReturn(new GetFunctionConfigurationResult().withRole(VALID_IAM_ARN)); + + final VaultCredentials credentials = provider.getCredentials(); + + assertThat(credentials.getToken()).isEqualTo(AUTH_TOKEN); + verify(lambdaClient, times(1)).getFunctionConfiguration(request); + } + + @Test(expected = VaultClientException.class) + public void VaultClientException_thrown_when_bad_json_returned() throws Exception { + final LambdaRoleVaultCredentialsProvider provider = PowerMockito.spy(new LambdaRoleVaultCredentialsProvider(urlResolver, VALID_LAMBDA_ARN)); + final GetFunctionConfigurationRequest request = new GetFunctionConfigurationRequest().withFunctionName("lambda-test").withQualifier("1.1.0"); + + + System.setProperty(DefaultCerberusUrlResolver.CERBERUS_ADDR_SYS_PROPERTY, vaultUrl); + mockWebServer.enqueue(new MockResponse().setResponseCode(200).setBody(BAD_AUTH_RESPONSE_JSON)); + + mockDecrypt(kmsClient, DECODED_AUTH_DATA); + + when(lambdaClient.getFunctionConfiguration(request)).thenReturn(new GetFunctionConfigurationResult().withRole(VALID_IAM_ARN)); + + provider.getCredentials(); + } + + @Test(expected = IllegalStateException.class) + public void authenticate_fails_when_lambda_has_invalid_assigned_role() throws Exception { + final LambdaRoleVaultCredentialsProvider provider = new LambdaRoleVaultCredentialsProvider(urlResolver, VALID_LAMBDA_ARN); + final GetFunctionConfigurationRequest request = new GetFunctionConfigurationRequest().withFunctionName("lambda-test").withQualifier("1.1.0"); + + when(lambdaClient.getFunctionConfiguration(request)).thenReturn(new GetFunctionConfigurationResult().withRole(INVALID_ARN)); + + provider.authenticate(); + } + + + @Test(expected = IllegalStateException.class) + public void authenticate_fails_when_lambda_has_no_assigned_role() throws Exception { + final LambdaRoleVaultCredentialsProvider provider = new LambdaRoleVaultCredentialsProvider(urlResolver, VALID_LAMBDA_ARN); + final GetFunctionConfigurationRequest request = new GetFunctionConfigurationRequest().withFunctionName("lambda-test").withQualifier("1.1.0"); + + when(lambdaClient.getFunctionConfiguration(request)).thenReturn(new GetFunctionConfigurationResult().withRole("")); + + provider.authenticate(); + } + + @After + public void resetMocks() { + reset(lambdaClient, kmsClient); + } + + +} \ No newline at end of file