Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Refactor KMS policy validation timeout logic
Browse files Browse the repository at this point in the history
  • Loading branch information
sdford committed Apr 27, 2017
1 parent e62de48 commit 3f8b67e
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,8 @@ public class AuthenticationService {
public static final String ADMIN_IAM_ROLES_PROPERTY = "cms.admin.roles";
public static final String USER_TOKEN_TTL_OVERRIDE = "cms.user.token.ttl.override";
public static final String IAM_TOKEN_TTL_OVERRIDE = "cms.iam.token.ttl.override";
public static final String KMS_POLICY_VALIDATION_INTERVAL_OVERRIDE = "cms.kms.policy.validation.interval.millis.override";
public static final String LOOKUP_SELF_POLICY = "lookup-self";
public static final String DEFAULT_TOKEN_TTL = "1h";
public static final Integer DEFAULT_KMS_VALIDATION_INTERVAL = 6000; // in milliseconds

private final SafeDepositBoxDao safeDepositBoxDao;
private final AwsIamRoleDao awsIamRoleDao;
Expand Down Expand Up @@ -116,10 +114,6 @@ public class AuthenticationService {
@Named(IAM_TOKEN_TTL_OVERRIDE)
String iamTokenTTL = DEFAULT_TOKEN_TTL;

@Inject(optional=true)
@Named(KMS_POLICY_VALIDATION_INTERVAL_OVERRIDE)
Integer kmsKeyPolicyValidationInterval = DEFAULT_KMS_VALIDATION_INTERVAL;

@Inject
public AuthenticationService(final SafeDepositBoxDao safeDepositBoxDao,
final AwsIamRoleDao awsIamRoleDao,
Expand Down Expand Up @@ -396,15 +390,7 @@ protected String getKeyId(IamPrincipalCredentials credentials) {
} else {
kmsKeyRecord = kmsKey.get();
kmsKeyId = kmsKeyRecord.getAwsKmsKeyId();
String keyRegion = credentials.getRegion();
if (ChronoUnit.MILLIS.between(kmsKeyRecord.getLastValidatedTs(), now) >= kmsKeyPolicyValidationInterval) {
try {
kmsService.validatePolicy(kmsKeyId, credentials.getIamPrincipalArn(), keyRegion);
kmsService.updateKmsKey(kmsKeyRecord, SYSTEM_USER, now, now);
} catch (ApiException ae) {
logger.warn("Could not validate KMS policy. API limit may have been reached for validate call.");
}
}
kmsService.validatePolicy(kmsKeyRecord, credentials.getIamPrincipalArn());
}

return kmsKeyId;
Expand Down
97 changes: 69 additions & 28 deletions src/main/java/com/nike/cerberus/service/KmsService.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
import com.amazonaws.services.kms.model.KeyMetadata;
import com.amazonaws.services.kms.model.KeyUsageType;
import com.amazonaws.services.kms.model.PutKeyPolicyRequest;
import com.google.inject.name.Named;
import com.nike.backstopper.exception.ApiException;
import com.nike.cerberus.aws.KmsClientFactory;
import com.nike.cerberus.dao.AwsIamRoleDao;
import com.nike.cerberus.error.DefaultApiError;
import com.nike.cerberus.record.AwsIamRoleKmsKeyRecord;
import com.nike.cerberus.util.DateTimeSupplier;
import com.nike.cerberus.util.UuidSupplier;
import org.mybatis.guice.transactional.Transactional;
import org.slf4j.Logger;
Expand All @@ -39,8 +41,11 @@
import javax.inject.Inject;
import javax.inject.Singleton;
import java.time.OffsetDateTime;
import java.time.temporal.ChronoUnit;
import java.util.Optional;

import static com.nike.cerberus.service.AuthenticationService.SYSTEM_USER;

/**
* Abstracts interactions with the AWS KMS service.
*/
Expand All @@ -50,6 +55,9 @@ public class KmsService {
private final Logger logger = LoggerFactory.getLogger(this.getClass());

private static final String KMS_ALIAS_FORMAT = "alias/cerberus/%s";
public static final String KMS_POLICY_VALIDATION_INTERVAL_OVERRIDE = "cms.kms.policy.validation.interval.millis.override";
public static final Integer DEFAULT_KMS_VALIDATION_INTERVAL = 6000; // in milliseconds


private final AwsIamRoleDao awsIamRoleDao;

Expand All @@ -59,15 +67,23 @@ public class KmsService {

private final KmsPolicyService kmsPolicyService;

private final DateTimeSupplier dateTimeSupplier;

@com.google.inject.Inject(optional=true)
@Named(KMS_POLICY_VALIDATION_INTERVAL_OVERRIDE)
Integer kmsKeyPolicyValidationInterval = DEFAULT_KMS_VALIDATION_INTERVAL;

@Inject
public KmsService(final AwsIamRoleDao awsIamRoleDao,
final UuidSupplier uuidSupplier,
final KmsClientFactory kmsClientFactory,
final KmsPolicyService kmsPolicyService) {
final KmsPolicyService kmsPolicyService,
final DateTimeSupplier dateTimeSupplier) {
this.awsIamRoleDao = awsIamRoleDao;
this.uuidSupplier = uuidSupplier;
this.kmsClientFactory = kmsClientFactory;
this.kmsPolicyService = kmsPolicyService;
this.dateTimeSupplier = dateTimeSupplier;
}

/**
Expand Down Expand Up @@ -119,13 +135,21 @@ public String provisionKmsKey(final String iamRoleId,
return result.getKeyMetadata().getArn();
}

/**
* Updates the KMS CMK record for the specified IAM role and region
* @param awsIamRoleId The IAM role that this CMK will be associated with
* @param awsRegion The region to provision the key in
* @param user The user requesting it
* @param lastedUpdatedTs The date when the record was last updated
* @param lastValidatedTs The date when the record was last validated
*/
@Transactional
public void updateKmsKey(final AwsIamRoleKmsKeyRecord awsIamRoleKmsKeyRecord,
public void updateKmsKey(final String awsIamRoleId,
final String awsRegion,
final String user,
final OffsetDateTime dateTime,
final OffsetDateTime lastedUpdatedTs,
final OffsetDateTime lastValidatedTs) {
final Optional<AwsIamRoleKmsKeyRecord> kmsKey =
awsIamRoleDao.getKmsKey(awsIamRoleKmsKeyRecord.getAwsIamRoleId(), awsIamRoleKmsKeyRecord.getAwsRegion());
final Optional<AwsIamRoleKmsKeyRecord> kmsKey = awsIamRoleDao.getKmsKey(awsIamRoleId, awsRegion);

if (!kmsKey.isPresent()) {
throw ApiException.newBuilder()
Expand All @@ -139,7 +163,7 @@ public void updateKmsKey(final AwsIamRoleKmsKeyRecord awsIamRoleKmsKeyRecord,
AwsIamRoleKmsKeyRecord updatedKmsKeyRecord = new AwsIamRoleKmsKeyRecord();
updatedKmsKeyRecord.setAwsIamRoleId(kmsKeyRecord.getAwsIamRoleId());
updatedKmsKeyRecord.setLastUpdatedBy(user);
updatedKmsKeyRecord.setLastUpdatedTs(dateTime);
updatedKmsKeyRecord.setLastUpdatedTs(lastedUpdatedTs);
updatedKmsKeyRecord.setLastValidatedTs(lastValidatedTs);
updatedKmsKeyRecord.setAwsRegion(kmsKeyRecord.getAwsRegion());
awsIamRoleDao.updateIamRoleKmsKey(updatedKmsKeyRecord);
Expand All @@ -162,34 +186,51 @@ protected String getAliasName(String awsIamRoleKmsKeyId) {
* statement has been deleted the ARN is replaced by the ID. We can validate that principal matches an ARN pattern
* or recreate the policy.
*
* @param keyId - The CMK Id to validate the policies on.
* @param kmsKeyRecord - The CMK record to validate policy on
* @param iamPrincipalArn - The principal ARN that should have decrypt permission
* @param kmsCMKRegion - The region that the key was provisioned for
*/
public void validatePolicy(String keyId, String iamPrincipalArn, String kmsCMKRegion) {
public void validatePolicy(AwsIamRoleKmsKeyRecord kmsKeyRecord, String iamPrincipalArn) {

if (! kmsPolicyNeedsValidation(kmsKeyRecord)) {
return;
}

String kmsCMKRegion = kmsKeyRecord.getAwsRegion();
String awsKmsKeyArn = kmsKeyRecord.getAwsKmsKeyId();
AWSKMSClient kmsClient = kmsClientFactory.getClient(kmsCMKRegion);
GetKeyPolicyResult policyResult = null;
try {
policyResult = kmsClient.getKeyPolicy(new GetKeyPolicyRequest().withKeyId(keyId).withPolicyName("default"));
GetKeyPolicyResult policyResult = kmsClient.getKeyPolicy(new GetKeyPolicyRequest().withKeyId(awsKmsKeyArn).withPolicyName("default"));

if (!kmsPolicyService.isPolicyValid(policyResult.getPolicy(), iamPrincipalArn)) {
logger.info("The KMS key: {} generated for IAM principal: {} contained an invalid policy, regenerating",
awsKmsKeyArn, iamPrincipalArn);
String updatedPolicy = kmsPolicyService.generateStandardKmsPolicy(iamPrincipalArn);
kmsClient.putKeyPolicy(new PutKeyPolicyRequest()
.withKeyId(awsKmsKeyArn)
.withPolicyName("default")
.withPolicy(updatedPolicy)
);
}

// update last validated timestamp
OffsetDateTime now = dateTimeSupplier.get();
updateKmsKey(kmsKeyRecord.getAwsIamRoleId(), kmsCMKRegion, SYSTEM_USER, now, now);
} catch (AmazonServiceException e) {
throw ApiException.newBuilder()
.withApiErrors(DefaultApiError.FAILED_TO_VALIDATE_KMS_KEY_POLICY)
.withExceptionCause(e)
.withExceptionMessage(
String.format("Failed to validate KMS key policy for keyId: " +
"%s for IAM principal: %s in region: %s", keyId, iamPrincipalArn, kmsCMKRegion))
.build();
logger.warn(String.format("Failed to validate KMS policy for keyId: %s for IAM principal: %s in region: %s. API limit" +
" may have been reached for validate call.", awsKmsKeyArn, iamPrincipalArn, kmsCMKRegion), e);
}
}

if (!kmsPolicyService.isPolicyValid(policyResult.getPolicy(), iamPrincipalArn)) {
logger.info("The KMS key: {} generated for IAM principal: {} contained an invalid policy, regenerating",
keyId, iamPrincipalArn);
String updatedPolicy = kmsPolicyService.generateStandardKmsPolicy(iamPrincipalArn);
kmsClient.putKeyPolicy(new PutKeyPolicyRequest()
.withKeyId(keyId)
.withPolicyName("default")
.withPolicy(updatedPolicy)
);
}
/**
* Determines if given KMS policy should be validated
* @param kmsKeyRecord - KMS key record to check for validation
* @return True if needs validation, False if not
*/
protected boolean kmsPolicyNeedsValidation(AwsIamRoleKmsKeyRecord kmsKeyRecord) {

OffsetDateTime now = dateTimeSupplier.get();
long timeSinceLastValidatedInMillis = ChronoUnit.MILLIS.between(kmsKeyRecord.getLastValidatedTs(), now);

return timeSinceLastValidatedInMillis >= kmsKeyPolicyValidationInterval;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ public void tests_that_getKeyId_only_validates_kms_policy_one_time_within_interv
String principalArn = "principal arn";
String region = "region";
String iamRoleId = "iam role id";
String kmsId = "kms id";
String keyId = "key id";
String kmsKeyId = "kms id";
String cmkId = "key id";

// ensure that validate interval is passed
OffsetDateTime dateTime = OffsetDateTime.of(2016, 1, 1, 1, 1, 1, 1, ZoneOffset.UTC);
Expand All @@ -138,26 +138,19 @@ public void tests_that_getKeyId_only_validates_kms_policy_one_time_within_interv
when(awsIamRoleDao.getIamRole(principalArn)).thenReturn(Optional.of(awsIamRoleRecord));

AwsIamRoleKmsKeyRecord awsIamRoleKmsKeyRecord = new AwsIamRoleKmsKeyRecord();
awsIamRoleKmsKeyRecord.setId(kmsId);
awsIamRoleKmsKeyRecord.setAwsKmsKeyId(keyId);
awsIamRoleKmsKeyRecord.setId(kmsKeyId);
awsIamRoleKmsKeyRecord.setAwsKmsKeyId(cmkId);
awsIamRoleKmsKeyRecord.setLastValidatedTs(dateTime);

when(awsIamRoleDao.getKmsKey(iamRoleId, region)).thenReturn(Optional.of(awsIamRoleKmsKeyRecord));

when(dateTimeSupplier.get()).thenReturn(now);

authenticationService.getKeyId(iamPrincipalCredentials);
String result = authenticationService.getKeyId(iamPrincipalCredentials);

// verify validate is called once interval has passed
verify(kmsService, times(1)).validatePolicy(keyId, principalArn, region);

// reset interval
awsIamRoleKmsKeyRecord.setLastValidatedTs(now);

// verify validate is not called when interval has not passed
authenticationService.getKeyId(iamPrincipalCredentials);
authenticationService.getKeyId(iamPrincipalCredentials);
verify(kmsService, times(1)).validatePolicy(keyId, principalArn, region);
assertEquals(cmkId, result);
verify(kmsService, times(1)).validatePolicy(awsIamRoleKmsKeyRecord, principalArn);
}

}
Loading

0 comments on commit 3f8b67e

Please sign in to comment.