Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Add caching to KMS auth (#211)
Browse files Browse the repository at this point in the history
Add caching to KMS auth
  • Loading branch information
mayitbeegh authored Oct 2, 2019
1 parent a4d90e9 commit 96b6dc1
Show file tree
Hide file tree
Showing 9 changed files with 285 additions and 7 deletions.
3 changes: 2 additions & 1 deletion gradle/dependencies.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ dependencies {
"com.okta.authn.sdk:okta-authn-sdk-api:0.1.0",
"com.okta.sdk:okta-sdk-httpclient:1.2.0",
"com.okta.authn.sdk:okta-authn-sdk-impl:0.1.0",
"org.reflections:reflections:0.9.11"
"org.reflections:reflections:0.9.11",
"com.github.ben-manes.caffeine:caffeine:2.8.0"

)

Expand Down
144 changes: 144 additions & 0 deletions src/main/java/com/nike/cerberus/cache/MetricReportingCache.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright (c) 2019 Nike, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package com.nike.cerberus.cache;

import com.codahale.metrics.Counter;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Policy;
import com.github.benmanes.caffeine.cache.stats.CacheStats;
import com.nike.cerberus.service.MetricsService;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import static com.github.benmanes.caffeine.cache.Caffeine.newBuilder;

/**
* A simple Caffeine backed cache that auto expires items after a certain time period,
* to help us against bursty traffic that does repeat reads.
*/
public class MetricReportingCache<K, V> implements Cache<K, V> {

private final Logger log = LoggerFactory.getLogger(getClass());

private com.github.benmanes.caffeine.cache.Cache<K, V> delegate;
private final Counter hitCounter;
private final Counter missCounter;

public MetricReportingCache(String namespace,int expireTimeInSeconds, MetricsService metricsService,
Map<String, String> dimensions) {
log.info("Cerberus cache with namespace: {} has been initialized with ttl: {}", namespace, expireTimeInSeconds);

delegate = newBuilder()
.expireAfterWrite(expireTimeInSeconds, TimeUnit.SECONDS)
.build();

// Create Metrics for this cache.
hitCounter = metricsService.getOrCreateCounter(String.format("cms.cache.%s.hit", namespace), dimensions);
missCounter = metricsService.getOrCreateCounter(String.format("cms.cache.%s.miss", namespace), dimensions);
metricsService.getOrCreateLongCallbackGauge(String.format("cms.cache.%s.size", namespace),
() -> delegate.estimatedSize(), dimensions);
metricsService.getOrCreateLongCallbackGauge(String.format("cms.cache.%s.stats.totalHitCount", namespace),
() -> delegate.stats().hitCount(), dimensions);
metricsService.getOrCreateLongCallbackGauge(String.format("cms.cache.%s.stats.totalMissCount", namespace),
() -> delegate.stats().missCount(), dimensions);
}

@Override
public V getIfPresent(Object key) {
V value = delegate.getIfPresent(key);
if (value == null) {
missCounter.inc();
} else {
hitCounter.inc();
}
return value;
}

@Override
public V get(K key, Function<? super K, ? extends V> mappingFunction) {
V value = delegate.getIfPresent(key);
if (value == null) {
missCounter.inc();
return delegate.get(key, mappingFunction);
} else {
hitCounter.inc();
return value;
}
}

@Override
public void put(K key, V value) {
delegate.put(key, value);
}

@Override
public void putAll(Map map) {
delegate.putAll(map);
}

@Override
public void invalidate(Object key) {
delegate.invalidate(key);
}

@Override
public void invalidateAll() {
delegate.invalidateAll();
}

@Override
public long estimatedSize() {
return delegate.estimatedSize();
}

@Override
public CacheStats stats() {
return delegate.stats();
}

@Override
public ConcurrentMap asMap() {
return delegate.asMap();
}

@Override
public void cleanUp() {
delegate.cleanUp();
}

@Override
public Policy policy() {
return delegate.policy();
}

@Override
public void invalidateAll(Iterable keys) {

}

@Override
public @NonNull Map getAllPresent(Iterable keys) {
return delegate.getAllPresent(keys);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import javax.validation.constraints.Pattern;

import java.util.Objects;

import static com.nike.cerberus.util.AwsIamRoleArnParser.AWS_IAM_PRINCIPAL_ARN_REGEX_ROLE_GENERATION;

/**
Expand Down Expand Up @@ -49,4 +51,18 @@ public String getRegion() {
public void setRegion(String region) {
this.region = region;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IamPrincipalCredentials that = (IamPrincipalCredentials) o;
return Objects.equals(iamPrincipalArn, that.iamPrincipalArn) &&
Objects.equals(region, that.region);
}

@Override
public int hashCode() {
return Objects.hash(iamPrincipalArn, region);
}
}
61 changes: 61 additions & 0 deletions src/main/java/com/nike/cerberus/metrics/CallbackLongGauge.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) 2019 Nike, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package com.nike.cerberus.metrics;

import com.codahale.metrics.Gauge;
import com.codahale.metrics.Metric;
import com.signalfx.codahale.metrics.MetricBuilder;

import java.util.function.Supplier;

public class CallbackLongGauge implements Metric, Gauge<Long> {

private final Supplier<Long> supplier;

public CallbackLongGauge(Supplier<Long> supplier) {
this.supplier = supplier;
}

@Override
public Long getValue() {
return supplier.get();
}

public static class Builder implements MetricBuilder<CallbackLongGauge> {

private final Supplier<Long> supplier;

public static Builder getInstance(Supplier<Long> supplier) {
return new Builder(supplier);
}

private Builder(Supplier<Long> supplier) {
this.supplier = supplier;
}

@Override
public CallbackLongGauge newMetric() {
return new CallbackLongGauge(supplier);
}

@Override
public boolean isInstance(Metric metric) {
return metric instanceof CallbackLongGauge;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@
import com.amazonaws.regions.Region;
import com.amazonaws.regions.Regions;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.benmanes.caffeine.cache.Cache;
import com.google.inject.*;
import com.google.inject.name.Names;
import com.nike.backstopper.apierror.projectspecificinfo.ProjectApiErrors;
import com.nike.cerberus.auth.connector.AuthConnector;
import com.nike.cerberus.aws.KmsClientFactory;
import com.nike.cerberus.cache.MetricReportingCache;
import com.nike.cerberus.domain.IamPrincipalCredentials;
import com.nike.cerberus.domain.IamRoleAuthResponse;
import com.nike.cerberus.endpoints.*;
import com.nike.cerberus.endpoints.authentication.*;
import com.nike.cerberus.endpoints.authentication.CodeHandlingMfaCheck;
Expand Down Expand Up @@ -74,6 +78,7 @@
import java.util.stream.Collectors;

import static com.nike.cerberus.service.EncryptionService.*;
import static com.github.benmanes.caffeine.cache.Caffeine.newBuilder;

public class CmsGuiceModule extends AbstractModule {

Expand Down Expand Up @@ -366,9 +371,19 @@ public Region currentRegion() {
return currentRegion;
}

@Provides
@Singleton
@Named("kmsAuthCache")
public Cache<IamPrincipalCredentials, IamRoleAuthResponse> kmsAuthCache(MetricsService metricsService,
KmsAuthCachingOptionalPropertyHolder kmsAuthCachingOptionalPropertyHolder) {
return new MetricReportingCache("auth.kms",
kmsAuthCachingOptionalPropertyHolder.maxAge,
metricsService, null);
}

/**
* This 'holder' class allows optional injection of KMS-data-key-caching-specific properties that are only necessary when
* SignalFx metrics reporting is enabled.
* KMS data key caching is enabled.
*
* The 'optional=true' parameter to Guice @Inject cannot be used in combination with the @Provides annotation
* or with constructor injection.
Expand Down Expand Up @@ -396,4 +411,19 @@ static class KmsDataKeyCachingOptionalPropertyHolder {
@com.google.inject.name.Named("cms.encryption.cache.decrypt.maxAgeInSeconds")
int decryptMaxAge = 0;
}

/**
* This 'holder' class allows optional injection of KMS-auth-caching-specific properties that are only necessary when
* KMS-auth caching is enabled.
*
* The 'optional=true' parameter to Guice @Inject cannot be used in combination with the @Provides annotation
* or with constructor injection.
*
* https://github.com/google/guice/wiki/FrequentlyAskedQuestions
*/
static class KmsAuthCachingOptionalPropertyHolder {
@Inject(optional=true)
@com.google.inject.name.Named("cms.iam.token.cache.maxAgeInSeconds")
int maxAge = 10;
}
}
24 changes: 20 additions & 4 deletions src/main/java/com/nike/cerberus/service/AuthenticationService.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.amazonaws.services.kms.model.KMSInvalidStateException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.benmanes.caffeine.cache.Cache;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -98,6 +99,8 @@ public class AuthenticationService {
public static final String MAX_TOKEN_REFRESH_COUNT = "cms.user.token.maxRefreshCount";
public static final String USER_TOKEN_TTL = "cms.user.token.ttl";
public static final String IAM_TOKEN_TTL = "cms.iam.token.ttl";
public static final String CACHE_ENABLED = "cms.auth.iam.kms.cache.enabled";
public static final String CACHE = "kmsAuthCache";
public static final String LOOKUP_SELF_POLICY = "lookup-self";
public static final int KMS_SIZE_LIMIT = 4096;

Expand All @@ -116,6 +119,8 @@ public class AuthenticationService {
private final String iamTokenTTL;
private final AwsIamRoleService awsIamRoleService;
private final int maxTokenRefreshCount;
private final boolean cacheEnabled;
private final Cache<IamPrincipalCredentials, IamRoleAuthResponse> cache;

@Inject(optional=true)
@Named(ADMIN_IAM_ROLES_PROPERTY)
Expand All @@ -138,8 +143,9 @@ public AuthenticationService(SafeDepositBoxDao safeDepositBoxDao,
AuthTokenService authTokenService,
@Named(USER_TOKEN_TTL) String userTokenTTL,
@Named(IAM_TOKEN_TTL) String iamTokenTTL,
AwsIamRoleService awsIamRoleService) {

AwsIamRoleService awsIamRoleService,
@Named(CACHE_ENABLED) boolean cacheEnabled,
@Named(CACHE) Cache<IamPrincipalCredentials, IamRoleAuthResponse> cache) {
this.safeDepositBoxDao = safeDepositBoxDao;
this.awsIamRoleDao = awsIamRoleDao;
this.authServiceConnector = authConnector;
Expand All @@ -155,6 +161,8 @@ public AuthenticationService(SafeDepositBoxDao safeDepositBoxDao,
this.userTokenTTL = userTokenTTL;
this.iamTokenTTL = iamTokenTTL;
this.awsIamRoleService = awsIamRoleService;
this.cacheEnabled = cacheEnabled;
this.cache = cache;
}

/**
Expand Down Expand Up @@ -228,7 +236,7 @@ public IamRoleAuthResponse authenticate(IamRoleCredentials credentials) {
authPrincipalMetadata.put(CerberusPrincipal.METADATA_KEY_AWS_ACCOUNT_ID, awsIamRoleArnParser.getAccountId(iamPrincipalArn));
authPrincipalMetadata.put(CerberusPrincipal.METADATA_KEY_AWS_IAM_ROLE_NAME, awsIamRoleArnParser.getRoleName(iamPrincipalArn));

return authenticate(iamPrincipalCredentials, authPrincipalMetadata);
return cachingKmsAuthenticate(iamPrincipalCredentials, authPrincipalMetadata);
}

public IamRoleAuthResponse authenticate(IamPrincipalCredentials credentials) {
Expand All @@ -237,7 +245,7 @@ public IamRoleAuthResponse authenticate(IamPrincipalCredentials credentials) {
final Map<String, String> authPrincipalMetadata = generateCommonIamPrincipalAuthMetadata(iamPrincipalArn, credentials.getRegion());
authPrincipalMetadata.put(CerberusPrincipal.METADATA_KEY_AWS_IAM_PRINCIPAL_ARN, iamPrincipalArn);

return authenticate(credentials, authPrincipalMetadata);
return cachingKmsAuthenticate(credentials, authPrincipalMetadata);
}

/**
Expand All @@ -257,6 +265,14 @@ public AuthTokenResponse stsAuthenticate(final String iamPrincipalArn) {
return authResponse;
}

private IamRoleAuthResponse cachingKmsAuthenticate(IamPrincipalCredentials credentials, Map<String, String> authPrincipalMetadata) {
if (cacheEnabled){
return cache.get(credentials, key -> authenticate(credentials, authPrincipalMetadata));
} else {
return authenticate(credentials, authPrincipalMetadata);
}
}

private IamRoleAuthResponse authenticate(IamPrincipalCredentials credentials, Map<String, String> authPrincipalMetadata) {
final AwsIamRoleKmsKeyRecord kmsKeyRecord;
final AwsIamRoleRecord iamRoleRecord;
Expand Down
Loading

0 comments on commit 96b6dc1

Please sign in to comment.