From 3f8b67e0520abe1a1f63ca9b8b18deae7f72914f Mon Sep 17 00:00:00 2001 From: Shaun Ford Date: Thu, 27 Apr 2017 13:43:26 -0700 Subject: [PATCH] Refactor KMS policy validation timeout logic --- .../service/AuthenticationService.java | 16 +-- .../com/nike/cerberus/service/KmsService.java | 97 ++++++++++++------ .../service/AuthenticationServiceTest.java | 21 ++-- .../nike/cerberus/service/KmsServiceTest.java | 98 +++++++++++++++---- 4 files changed, 157 insertions(+), 75 deletions(-) diff --git a/src/main/java/com/nike/cerberus/service/AuthenticationService.java b/src/main/java/com/nike/cerberus/service/AuthenticationService.java index 87308b4cb..574e6cf26 100644 --- a/src/main/java/com/nike/cerberus/service/AuthenticationService.java +++ b/src/main/java/com/nike/cerberus/service/AuthenticationService.java @@ -85,10 +85,8 @@ public class AuthenticationService { public static final String ADMIN_IAM_ROLES_PROPERTY = "cms.admin.roles"; public static final String USER_TOKEN_TTL_OVERRIDE = "cms.user.token.ttl.override"; public static final String IAM_TOKEN_TTL_OVERRIDE = "cms.iam.token.ttl.override"; - public static final String KMS_POLICY_VALIDATION_INTERVAL_OVERRIDE = "cms.kms.policy.validation.interval.millis.override"; public static final String LOOKUP_SELF_POLICY = "lookup-self"; public static final String DEFAULT_TOKEN_TTL = "1h"; - public static final Integer DEFAULT_KMS_VALIDATION_INTERVAL = 6000; // in milliseconds private final SafeDepositBoxDao safeDepositBoxDao; private final AwsIamRoleDao awsIamRoleDao; @@ -116,10 +114,6 @@ public class AuthenticationService { @Named(IAM_TOKEN_TTL_OVERRIDE) String iamTokenTTL = DEFAULT_TOKEN_TTL; - @Inject(optional=true) - @Named(KMS_POLICY_VALIDATION_INTERVAL_OVERRIDE) - Integer kmsKeyPolicyValidationInterval = DEFAULT_KMS_VALIDATION_INTERVAL; - @Inject public AuthenticationService(final SafeDepositBoxDao safeDepositBoxDao, final AwsIamRoleDao awsIamRoleDao, @@ -396,15 +390,7 @@ protected String getKeyId(IamPrincipalCredentials credentials) { } else { kmsKeyRecord = kmsKey.get(); kmsKeyId = kmsKeyRecord.getAwsKmsKeyId(); - String keyRegion = credentials.getRegion(); - if (ChronoUnit.MILLIS.between(kmsKeyRecord.getLastValidatedTs(), now) >= kmsKeyPolicyValidationInterval) { - try { - kmsService.validatePolicy(kmsKeyId, credentials.getIamPrincipalArn(), keyRegion); - kmsService.updateKmsKey(kmsKeyRecord, SYSTEM_USER, now, now); - } catch (ApiException ae) { - logger.warn("Could not validate KMS policy. API limit may have been reached for validate call."); - } - } + kmsService.validatePolicy(kmsKeyRecord, credentials.getIamPrincipalArn()); } return kmsKeyId; diff --git a/src/main/java/com/nike/cerberus/service/KmsService.java b/src/main/java/com/nike/cerberus/service/KmsService.java index 819b225d8..418475792 100644 --- a/src/main/java/com/nike/cerberus/service/KmsService.java +++ b/src/main/java/com/nike/cerberus/service/KmsService.java @@ -26,11 +26,13 @@ import com.amazonaws.services.kms.model.KeyMetadata; import com.amazonaws.services.kms.model.KeyUsageType; import com.amazonaws.services.kms.model.PutKeyPolicyRequest; +import com.google.inject.name.Named; import com.nike.backstopper.exception.ApiException; import com.nike.cerberus.aws.KmsClientFactory; import com.nike.cerberus.dao.AwsIamRoleDao; import com.nike.cerberus.error.DefaultApiError; import com.nike.cerberus.record.AwsIamRoleKmsKeyRecord; +import com.nike.cerberus.util.DateTimeSupplier; import com.nike.cerberus.util.UuidSupplier; import org.mybatis.guice.transactional.Transactional; import org.slf4j.Logger; @@ -39,8 +41,11 @@ import javax.inject.Inject; import javax.inject.Singleton; import java.time.OffsetDateTime; +import java.time.temporal.ChronoUnit; import java.util.Optional; +import static com.nike.cerberus.service.AuthenticationService.SYSTEM_USER; + /** * Abstracts interactions with the AWS KMS service. */ @@ -50,6 +55,9 @@ public class KmsService { private final Logger logger = LoggerFactory.getLogger(this.getClass()); private static final String KMS_ALIAS_FORMAT = "alias/cerberus/%s"; + public static final String KMS_POLICY_VALIDATION_INTERVAL_OVERRIDE = "cms.kms.policy.validation.interval.millis.override"; + public static final Integer DEFAULT_KMS_VALIDATION_INTERVAL = 6000; // in milliseconds + private final AwsIamRoleDao awsIamRoleDao; @@ -59,15 +67,23 @@ public class KmsService { private final KmsPolicyService kmsPolicyService; + private final DateTimeSupplier dateTimeSupplier; + + @com.google.inject.Inject(optional=true) + @Named(KMS_POLICY_VALIDATION_INTERVAL_OVERRIDE) + Integer kmsKeyPolicyValidationInterval = DEFAULT_KMS_VALIDATION_INTERVAL; + @Inject public KmsService(final AwsIamRoleDao awsIamRoleDao, final UuidSupplier uuidSupplier, final KmsClientFactory kmsClientFactory, - final KmsPolicyService kmsPolicyService) { + final KmsPolicyService kmsPolicyService, + final DateTimeSupplier dateTimeSupplier) { this.awsIamRoleDao = awsIamRoleDao; this.uuidSupplier = uuidSupplier; this.kmsClientFactory = kmsClientFactory; this.kmsPolicyService = kmsPolicyService; + this.dateTimeSupplier = dateTimeSupplier; } /** @@ -119,13 +135,21 @@ public String provisionKmsKey(final String iamRoleId, return result.getKeyMetadata().getArn(); } + /** + * Updates the KMS CMK record for the specified IAM role and region + * @param awsIamRoleId The IAM role that this CMK will be associated with + * @param awsRegion The region to provision the key in + * @param user The user requesting it + * @param lastedUpdatedTs The date when the record was last updated + * @param lastValidatedTs The date when the record was last validated + */ @Transactional - public void updateKmsKey(final AwsIamRoleKmsKeyRecord awsIamRoleKmsKeyRecord, + public void updateKmsKey(final String awsIamRoleId, + final String awsRegion, final String user, - final OffsetDateTime dateTime, + final OffsetDateTime lastedUpdatedTs, final OffsetDateTime lastValidatedTs) { - final Optional kmsKey = - awsIamRoleDao.getKmsKey(awsIamRoleKmsKeyRecord.getAwsIamRoleId(), awsIamRoleKmsKeyRecord.getAwsRegion()); + final Optional kmsKey = awsIamRoleDao.getKmsKey(awsIamRoleId, awsRegion); if (!kmsKey.isPresent()) { throw ApiException.newBuilder() @@ -139,7 +163,7 @@ public void updateKmsKey(final AwsIamRoleKmsKeyRecord awsIamRoleKmsKeyRecord, AwsIamRoleKmsKeyRecord updatedKmsKeyRecord = new AwsIamRoleKmsKeyRecord(); updatedKmsKeyRecord.setAwsIamRoleId(kmsKeyRecord.getAwsIamRoleId()); updatedKmsKeyRecord.setLastUpdatedBy(user); - updatedKmsKeyRecord.setLastUpdatedTs(dateTime); + updatedKmsKeyRecord.setLastUpdatedTs(lastedUpdatedTs); updatedKmsKeyRecord.setLastValidatedTs(lastValidatedTs); updatedKmsKeyRecord.setAwsRegion(kmsKeyRecord.getAwsRegion()); awsIamRoleDao.updateIamRoleKmsKey(updatedKmsKeyRecord); @@ -162,34 +186,51 @@ protected String getAliasName(String awsIamRoleKmsKeyId) { * statement has been deleted the ARN is replaced by the ID. We can validate that principal matches an ARN pattern * or recreate the policy. * - * @param keyId - The CMK Id to validate the policies on. + * @param kmsKeyRecord - The CMK record to validate policy on * @param iamPrincipalArn - The principal ARN that should have decrypt permission - * @param kmsCMKRegion - The region that the key was provisioned for */ - public void validatePolicy(String keyId, String iamPrincipalArn, String kmsCMKRegion) { + public void validatePolicy(AwsIamRoleKmsKeyRecord kmsKeyRecord, String iamPrincipalArn) { + + if (! kmsPolicyNeedsValidation(kmsKeyRecord)) { + return; + } + + String kmsCMKRegion = kmsKeyRecord.getAwsRegion(); + String awsKmsKeyArn = kmsKeyRecord.getAwsKmsKeyId(); AWSKMSClient kmsClient = kmsClientFactory.getClient(kmsCMKRegion); - GetKeyPolicyResult policyResult = null; try { - policyResult = kmsClient.getKeyPolicy(new GetKeyPolicyRequest().withKeyId(keyId).withPolicyName("default")); + GetKeyPolicyResult policyResult = kmsClient.getKeyPolicy(new GetKeyPolicyRequest().withKeyId(awsKmsKeyArn).withPolicyName("default")); + + if (!kmsPolicyService.isPolicyValid(policyResult.getPolicy(), iamPrincipalArn)) { + logger.info("The KMS key: {} generated for IAM principal: {} contained an invalid policy, regenerating", + awsKmsKeyArn, iamPrincipalArn); + String updatedPolicy = kmsPolicyService.generateStandardKmsPolicy(iamPrincipalArn); + kmsClient.putKeyPolicy(new PutKeyPolicyRequest() + .withKeyId(awsKmsKeyArn) + .withPolicyName("default") + .withPolicy(updatedPolicy) + ); + } + + // update last validated timestamp + OffsetDateTime now = dateTimeSupplier.get(); + updateKmsKey(kmsKeyRecord.getAwsIamRoleId(), kmsCMKRegion, SYSTEM_USER, now, now); } catch (AmazonServiceException e) { - throw ApiException.newBuilder() - .withApiErrors(DefaultApiError.FAILED_TO_VALIDATE_KMS_KEY_POLICY) - .withExceptionCause(e) - .withExceptionMessage( - String.format("Failed to validate KMS key policy for keyId: " + - "%s for IAM principal: %s in region: %s", keyId, iamPrincipalArn, kmsCMKRegion)) - .build(); + logger.warn(String.format("Failed to validate KMS policy for keyId: %s for IAM principal: %s in region: %s. API limit" + + " may have been reached for validate call.", awsKmsKeyArn, iamPrincipalArn, kmsCMKRegion), e); } + } - if (!kmsPolicyService.isPolicyValid(policyResult.getPolicy(), iamPrincipalArn)) { - logger.info("The KMS key: {} generated for IAM principal: {} contained an invalid policy, regenerating", - keyId, iamPrincipalArn); - String updatedPolicy = kmsPolicyService.generateStandardKmsPolicy(iamPrincipalArn); - kmsClient.putKeyPolicy(new PutKeyPolicyRequest() - .withKeyId(keyId) - .withPolicyName("default") - .withPolicy(updatedPolicy) - ); - } + /** + * Determines if given KMS policy should be validated + * @param kmsKeyRecord - KMS key record to check for validation + * @return True if needs validation, False if not + */ + protected boolean kmsPolicyNeedsValidation(AwsIamRoleKmsKeyRecord kmsKeyRecord) { + + OffsetDateTime now = dateTimeSupplier.get(); + long timeSinceLastValidatedInMillis = ChronoUnit.MILLIS.between(kmsKeyRecord.getLastValidatedTs(), now); + + return timeSinceLastValidatedInMillis >= kmsKeyPolicyValidationInterval; } } diff --git a/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java b/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java index d87540443..be1b993f2 100644 --- a/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java +++ b/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java @@ -121,8 +121,8 @@ public void tests_that_getKeyId_only_validates_kms_policy_one_time_within_interv String principalArn = "principal arn"; String region = "region"; String iamRoleId = "iam role id"; - String kmsId = "kms id"; - String keyId = "key id"; + String kmsKeyId = "kms id"; + String cmkId = "key id"; // ensure that validate interval is passed OffsetDateTime dateTime = OffsetDateTime.of(2016, 1, 1, 1, 1, 1, 1, ZoneOffset.UTC); @@ -138,26 +138,19 @@ public void tests_that_getKeyId_only_validates_kms_policy_one_time_within_interv when(awsIamRoleDao.getIamRole(principalArn)).thenReturn(Optional.of(awsIamRoleRecord)); AwsIamRoleKmsKeyRecord awsIamRoleKmsKeyRecord = new AwsIamRoleKmsKeyRecord(); - awsIamRoleKmsKeyRecord.setId(kmsId); - awsIamRoleKmsKeyRecord.setAwsKmsKeyId(keyId); + awsIamRoleKmsKeyRecord.setId(kmsKeyId); + awsIamRoleKmsKeyRecord.setAwsKmsKeyId(cmkId); awsIamRoleKmsKeyRecord.setLastValidatedTs(dateTime); when(awsIamRoleDao.getKmsKey(iamRoleId, region)).thenReturn(Optional.of(awsIamRoleKmsKeyRecord)); when(dateTimeSupplier.get()).thenReturn(now); - authenticationService.getKeyId(iamPrincipalCredentials); + String result = authenticationService.getKeyId(iamPrincipalCredentials); // verify validate is called once interval has passed - verify(kmsService, times(1)).validatePolicy(keyId, principalArn, region); - - // reset interval - awsIamRoleKmsKeyRecord.setLastValidatedTs(now); - - // verify validate is not called when interval has not passed - authenticationService.getKeyId(iamPrincipalCredentials); - authenticationService.getKeyId(iamPrincipalCredentials); - verify(kmsService, times(1)).validatePolicy(keyId, principalArn, region); + assertEquals(cmkId, result); + verify(kmsService, times(1)).validatePolicy(awsIamRoleKmsKeyRecord, principalArn); } } \ No newline at end of file diff --git a/src/test/java/com/nike/cerberus/service/KmsServiceTest.java b/src/test/java/com/nike/cerberus/service/KmsServiceTest.java index fc19d08bb..575ca9242 100644 --- a/src/test/java/com/nike/cerberus/service/KmsServiceTest.java +++ b/src/test/java/com/nike/cerberus/service/KmsServiceTest.java @@ -1,16 +1,19 @@ package com.nike.cerberus.service; +import com.amazonaws.AmazonServiceException; import com.amazonaws.services.kms.AWSKMSClient; import com.amazonaws.services.kms.model.*; import com.nike.backstopper.exception.ApiException; import com.nike.cerberus.aws.KmsClientFactory; import com.nike.cerberus.dao.AwsIamRoleDao; import com.nike.cerberus.record.AwsIamRoleKmsKeyRecord; +import com.nike.cerberus.util.DateTimeSupplier; import com.nike.cerberus.util.UuidSupplier; import org.junit.Before; import org.junit.Test; import java.time.OffsetDateTime; +import java.time.ZoneOffset; import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -22,6 +25,7 @@ public class KmsServiceTest { private UuidSupplier uuidSupplier; private KmsClientFactory kmsClientFactory; private KmsPolicyService kmsPolicyService; + private DateTimeSupplier dateTimeSupplier; private KmsService kmsService; @@ -31,7 +35,9 @@ public void setup() { uuidSupplier = mock(UuidSupplier.class); kmsClientFactory = mock(KmsClientFactory.class); kmsPolicyService = mock(KmsPolicyService.class); - kmsService = new KmsService(awsIamRoleDao, uuidSupplier, kmsClientFactory, kmsPolicyService); + dateTimeSupplier = mock(DateTimeSupplier.class); + + kmsService = new KmsService(awsIamRoleDao, uuidSupplier, kmsClientFactory, kmsPolicyService, dateTimeSupplier); } @Test @@ -93,24 +99,88 @@ public void test_getAliasName() { } @Test - public void test_validatePolicy() { + public void test_validatePolicy_validates_policy_when_validate_interval_has_passed() { + String kmsKeyArn = "kms key arn"; + String awsIamRoleRecordId = "aws iam role record id"; + String kmsCMKRegion = "kmsCMKRegion"; + String policy = "policy"; + OffsetDateTime lastValidated = OffsetDateTime.of(2016, 1, 1, 1, 1, + 1, 1, ZoneOffset.UTC); + OffsetDateTime now = OffsetDateTime.now(); + + AWSKMSClient client = mock(AWSKMSClient.class); + when(kmsClientFactory.getClient(kmsCMKRegion)).thenReturn(client); + + GetKeyPolicyResult result = mock(GetKeyPolicyResult.class); + when(result.getPolicy()).thenReturn(policy); + when(client.getKeyPolicy(new GetKeyPolicyRequest().withKeyId(kmsKeyArn) + .withPolicyName("default"))).thenReturn(result); + when(kmsPolicyService.isPolicyValid(policy, kmsKeyArn)).thenReturn(true); + + AwsIamRoleKmsKeyRecord kmsKey = mock(AwsIamRoleKmsKeyRecord.class); + when(kmsKey.getAwsIamRoleId()).thenReturn(awsIamRoleRecordId); + when(kmsKey.getAwsKmsKeyId()).thenReturn(kmsKeyArn); + when(kmsKey.getAwsRegion()).thenReturn(kmsCMKRegion); + when(kmsKey.getLastValidatedTs()).thenReturn(lastValidated); + when(awsIamRoleDao.getKmsKey(awsIamRoleRecordId, kmsCMKRegion)).thenReturn(Optional.of(kmsKey)); + + when(dateTimeSupplier.get()).thenReturn(now); + kmsService.validatePolicy(kmsKey, kmsKeyArn); + + verify(client, times(1)).getKeyPolicy(new GetKeyPolicyRequest().withKeyId(kmsKeyArn) + .withPolicyName("default")); + verify(kmsPolicyService, times(1)).isPolicyValid(policy, kmsKeyArn); + } + + @Test + public void test_validatePolicy_validates_policy_when_validate_interval_has_not_passed() { + String awsKmsKeyArn = "aws kms key arn"; + String iamPrincipalArn = "arn"; + String awsIamRoleRecordId = "aws iam role record id"; + String kmsCMKRegion = "kmsCMKRegion"; + OffsetDateTime now = OffsetDateTime.now(); + + AwsIamRoleKmsKeyRecord kmsKey = mock(AwsIamRoleKmsKeyRecord.class); + when(kmsKey.getAwsKmsKeyId()).thenReturn(awsKmsKeyArn); + when(kmsKey.getAwsIamRoleId()).thenReturn(awsIamRoleRecordId); + when(kmsKey.getAwsRegion()).thenReturn(kmsCMKRegion); + when(kmsKey.getLastValidatedTs()).thenReturn(now); + + when(dateTimeSupplier.get()).thenReturn(now); + kmsService.validatePolicy(kmsKey, iamPrincipalArn); + + verify(kmsClientFactory, never()).getClient(anyString()); + verify(kmsPolicyService, never()).isPolicyValid(anyString(), anyString()); + } + + @Test + public void test_validatePolicy_does_not_throw_error_when_cannot_validate() { String keyId = "key-id"; - String iamRoleArn = "arn"; + String iamPrincipalArn = "arn"; String kmsCMKRegion = "kmsCMKRegion"; String policy = "policy"; + OffsetDateTime lastValidated = OffsetDateTime.of(2016, 1, 1, 1, 1, + 1, 1, ZoneOffset.UTC); + OffsetDateTime now = OffsetDateTime.now(); + when(dateTimeSupplier.get()).thenReturn(now); + + AwsIamRoleKmsKeyRecord kmsKey = mock(AwsIamRoleKmsKeyRecord.class); + when(kmsKey.getAwsKmsKeyId()).thenReturn(keyId); + when(kmsKey.getAwsIamRoleId()).thenReturn(iamPrincipalArn); + when(kmsKey.getAwsRegion()).thenReturn(kmsCMKRegion); + when(kmsKey.getLastValidatedTs()).thenReturn(lastValidated); AWSKMSClient client = mock(AWSKMSClient.class); when(kmsClientFactory.getClient(kmsCMKRegion)).thenReturn(client); GetKeyPolicyResult result = mock(GetKeyPolicyResult.class); when(result.getPolicy()).thenReturn(policy); - when(client.getKeyPolicy(new GetKeyPolicyRequest().withKeyId(keyId).withPolicyName("default"))).thenReturn(result); - when(kmsPolicyService.isPolicyValid(policy, iamRoleArn)).thenReturn(true); + when(client.getKeyPolicy(new GetKeyPolicyRequest().withKeyId(keyId).withPolicyName("default"))).thenThrow(AmazonServiceException.class); - kmsService.validatePolicy(keyId, iamRoleArn, kmsCMKRegion); + kmsService.validatePolicy(kmsKey, iamPrincipalArn); - verify(client, times(1)).getKeyPolicy(new GetKeyPolicyRequest().withKeyId(keyId).withPolicyName("default")); - verify(kmsPolicyService, times(1)).isPolicyValid(policy, iamRoleArn); + verify(kmsPolicyService, never()).isPolicyValid(policy, iamPrincipalArn); + verify(client, never()).putKeyPolicy(anyObject()); } @Test @@ -121,17 +191,13 @@ public void test_updateKmsKey() { String user = "user"; OffsetDateTime dateTime = OffsetDateTime.now(); - AwsIamRoleKmsKeyRecord awsIamRoleKmsKeyRecord = new AwsIamRoleKmsKeyRecord(); - awsIamRoleKmsKeyRecord.setAwsRegion(awsRegion); - awsIamRoleKmsKeyRecord.setAwsIamRoleId(iamRoleId); - AwsIamRoleKmsKeyRecord dbRecord = new AwsIamRoleKmsKeyRecord(); dbRecord.setAwsRegion(awsRegion); dbRecord.setAwsIamRoleId(iamRoleId); dbRecord.setLastValidatedTs(OffsetDateTime.now()); when(awsIamRoleDao.getKmsKey(iamRoleId, awsRegion)).thenReturn(Optional.of(dbRecord)); - kmsService.updateKmsKey(awsIamRoleKmsKeyRecord, user, dateTime, dateTime); + kmsService.updateKmsKey(iamRoleId, awsRegion, user, dateTime, dateTime); AwsIamRoleKmsKeyRecord expected = new AwsIamRoleKmsKeyRecord(); expected.setAwsIamRoleId(iamRoleId); @@ -151,12 +217,8 @@ public void test_updateKmsKey_fails_when_record_not_found() { String user = "user"; OffsetDateTime dateTime = OffsetDateTime.now(); - AwsIamRoleKmsKeyRecord awsIamRoleKmsKeyRecord = new AwsIamRoleKmsKeyRecord(); - awsIamRoleKmsKeyRecord.setAwsRegion(awsRegion); - awsIamRoleKmsKeyRecord.setAwsIamRoleId(iamRoleId); - when(awsIamRoleDao.getKmsKey(iamRoleId, awsRegion)).thenReturn(Optional.empty()); - kmsService.updateKmsKey(awsIamRoleKmsKeyRecord, user, dateTime, dateTime); + kmsService.updateKmsKey(iamRoleId, awsRegion, user, dateTime, dateTime); } } \ No newline at end of file