From 96b6dc16dc41e2d7aef4a7315829753c7d3023f6 Mon Sep 17 00:00:00 2001 From: Sean Lin Date: Wed, 2 Oct 2019 11:18:32 -0700 Subject: [PATCH] Add caching to KMS auth (#211) Add caching to KMS auth --- gradle/dependencies.gradle | 3 +- .../cerberus/cache/MetricReportingCache.java | 144 ++++++++++++++++++ .../domain/IamPrincipalCredentials.java | 16 ++ .../cerberus/metrics/CallbackLongGauge.java | 61 ++++++++ .../server/config/guice/CmsGuiceModule.java | 32 +++- .../service/AuthenticationService.java | 24 ++- .../nike/cerberus/service/MetricsService.java | 7 + src/main/resources/cms.conf | 1 + .../service/AuthenticationServiceTest.java | 4 +- 9 files changed, 285 insertions(+), 7 deletions(-) create mode 100644 src/main/java/com/nike/cerberus/cache/MetricReportingCache.java create mode 100644 src/main/java/com/nike/cerberus/metrics/CallbackLongGauge.java diff --git a/gradle/dependencies.gradle b/gradle/dependencies.gradle index b41f182d3..cc297acfe 100644 --- a/gradle/dependencies.gradle +++ b/gradle/dependencies.gradle @@ -76,7 +76,8 @@ dependencies { "com.okta.authn.sdk:okta-authn-sdk-api:0.1.0", "com.okta.sdk:okta-sdk-httpclient:1.2.0", "com.okta.authn.sdk:okta-authn-sdk-impl:0.1.0", - "org.reflections:reflections:0.9.11" + "org.reflections:reflections:0.9.11", + "com.github.ben-manes.caffeine:caffeine:2.8.0" ) diff --git a/src/main/java/com/nike/cerberus/cache/MetricReportingCache.java b/src/main/java/com/nike/cerberus/cache/MetricReportingCache.java new file mode 100644 index 000000000..0ec6413a2 --- /dev/null +++ b/src/main/java/com/nike/cerberus/cache/MetricReportingCache.java @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2019 Nike, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.nike.cerberus.cache; + +import com.codahale.metrics.Counter; +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Policy; +import com.github.benmanes.caffeine.cache.stats.CacheStats; +import com.nike.cerberus.service.MetricsService; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import static com.github.benmanes.caffeine.cache.Caffeine.newBuilder; + +/** + * A simple Caffeine backed cache that auto expires items after a certain time period, + * to help us against bursty traffic that does repeat reads. + */ +public class MetricReportingCache implements Cache { + + private final Logger log = LoggerFactory.getLogger(getClass()); + + private com.github.benmanes.caffeine.cache.Cache delegate; + private final Counter hitCounter; + private final Counter missCounter; + + public MetricReportingCache(String namespace,int expireTimeInSeconds, MetricsService metricsService, + Map dimensions) { + log.info("Cerberus cache with namespace: {} has been initialized with ttl: {}", namespace, expireTimeInSeconds); + + delegate = newBuilder() + .expireAfterWrite(expireTimeInSeconds, TimeUnit.SECONDS) + .build(); + + // Create Metrics for this cache. + hitCounter = metricsService.getOrCreateCounter(String.format("cms.cache.%s.hit", namespace), dimensions); + missCounter = metricsService.getOrCreateCounter(String.format("cms.cache.%s.miss", namespace), dimensions); + metricsService.getOrCreateLongCallbackGauge(String.format("cms.cache.%s.size", namespace), + () -> delegate.estimatedSize(), dimensions); + metricsService.getOrCreateLongCallbackGauge(String.format("cms.cache.%s.stats.totalHitCount", namespace), + () -> delegate.stats().hitCount(), dimensions); + metricsService.getOrCreateLongCallbackGauge(String.format("cms.cache.%s.stats.totalMissCount", namespace), + () -> delegate.stats().missCount(), dimensions); + } + + @Override + public V getIfPresent(Object key) { + V value = delegate.getIfPresent(key); + if (value == null) { + missCounter.inc(); + } else { + hitCounter.inc(); + } + return value; + } + + @Override + public V get(K key, Function mappingFunction) { + V value = delegate.getIfPresent(key); + if (value == null) { + missCounter.inc(); + return delegate.get(key, mappingFunction); + } else { + hitCounter.inc(); + return value; + } + } + + @Override + public void put(K key, V value) { + delegate.put(key, value); + } + + @Override + public void putAll(Map map) { + delegate.putAll(map); + } + + @Override + public void invalidate(Object key) { + delegate.invalidate(key); + } + + @Override + public void invalidateAll() { + delegate.invalidateAll(); + } + + @Override + public long estimatedSize() { + return delegate.estimatedSize(); + } + + @Override + public CacheStats stats() { + return delegate.stats(); + } + + @Override + public ConcurrentMap asMap() { + return delegate.asMap(); + } + + @Override + public void cleanUp() { + delegate.cleanUp(); + } + + @Override + public Policy policy() { + return delegate.policy(); + } + + @Override + public void invalidateAll(Iterable keys) { + + } + + @Override + public @NonNull Map getAllPresent(Iterable keys) { + return delegate.getAllPresent(keys); + } +} \ No newline at end of file diff --git a/src/main/java/com/nike/cerberus/domain/IamPrincipalCredentials.java b/src/main/java/com/nike/cerberus/domain/IamPrincipalCredentials.java index 1ec274160..aa7ff365d 100644 --- a/src/main/java/com/nike/cerberus/domain/IamPrincipalCredentials.java +++ b/src/main/java/com/nike/cerberus/domain/IamPrincipalCredentials.java @@ -21,6 +21,8 @@ import javax.validation.constraints.Pattern; +import java.util.Objects; + import static com.nike.cerberus.util.AwsIamRoleArnParser.AWS_IAM_PRINCIPAL_ARN_REGEX_ROLE_GENERATION; /** @@ -49,4 +51,18 @@ public String getRegion() { public void setRegion(String region) { this.region = region; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IamPrincipalCredentials that = (IamPrincipalCredentials) o; + return Objects.equals(iamPrincipalArn, that.iamPrincipalArn) && + Objects.equals(region, that.region); + } + + @Override + public int hashCode() { + return Objects.hash(iamPrincipalArn, region); + } } diff --git a/src/main/java/com/nike/cerberus/metrics/CallbackLongGauge.java b/src/main/java/com/nike/cerberus/metrics/CallbackLongGauge.java new file mode 100644 index 000000000..6f87f86d3 --- /dev/null +++ b/src/main/java/com/nike/cerberus/metrics/CallbackLongGauge.java @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2019 Nike, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package com.nike.cerberus.metrics; + +import com.codahale.metrics.Gauge; +import com.codahale.metrics.Metric; +import com.signalfx.codahale.metrics.MetricBuilder; + +import java.util.function.Supplier; + +public class CallbackLongGauge implements Metric, Gauge { + + private final Supplier supplier; + + public CallbackLongGauge(Supplier supplier) { + this.supplier = supplier; + } + + @Override + public Long getValue() { + return supplier.get(); + } + + public static class Builder implements MetricBuilder { + + private final Supplier supplier; + + public static Builder getInstance(Supplier supplier) { + return new Builder(supplier); + } + + private Builder(Supplier supplier) { + this.supplier = supplier; + } + + @Override + public CallbackLongGauge newMetric() { + return new CallbackLongGauge(supplier); + } + + @Override + public boolean isInstance(Metric metric) { + return metric instanceof CallbackLongGauge; + } + } +} \ No newline at end of file diff --git a/src/main/java/com/nike/cerberus/server/config/guice/CmsGuiceModule.java b/src/main/java/com/nike/cerberus/server/config/guice/CmsGuiceModule.java index 6e8c6d6d6..84e6609e0 100644 --- a/src/main/java/com/nike/cerberus/server/config/guice/CmsGuiceModule.java +++ b/src/main/java/com/nike/cerberus/server/config/guice/CmsGuiceModule.java @@ -27,11 +27,15 @@ import com.amazonaws.regions.Region; import com.amazonaws.regions.Regions; import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.benmanes.caffeine.cache.Cache; import com.google.inject.*; import com.google.inject.name.Names; import com.nike.backstopper.apierror.projectspecificinfo.ProjectApiErrors; import com.nike.cerberus.auth.connector.AuthConnector; import com.nike.cerberus.aws.KmsClientFactory; +import com.nike.cerberus.cache.MetricReportingCache; +import com.nike.cerberus.domain.IamPrincipalCredentials; +import com.nike.cerberus.domain.IamRoleAuthResponse; import com.nike.cerberus.endpoints.*; import com.nike.cerberus.endpoints.authentication.*; import com.nike.cerberus.endpoints.authentication.CodeHandlingMfaCheck; @@ -74,6 +78,7 @@ import java.util.stream.Collectors; import static com.nike.cerberus.service.EncryptionService.*; +import static com.github.benmanes.caffeine.cache.Caffeine.newBuilder; public class CmsGuiceModule extends AbstractModule { @@ -366,9 +371,19 @@ public Region currentRegion() { return currentRegion; } + @Provides + @Singleton + @Named("kmsAuthCache") + public Cache kmsAuthCache(MetricsService metricsService, + KmsAuthCachingOptionalPropertyHolder kmsAuthCachingOptionalPropertyHolder) { + return new MetricReportingCache("auth.kms", + kmsAuthCachingOptionalPropertyHolder.maxAge, + metricsService, null); + } + /** * This 'holder' class allows optional injection of KMS-data-key-caching-specific properties that are only necessary when - * SignalFx metrics reporting is enabled. + * KMS data key caching is enabled. * * The 'optional=true' parameter to Guice @Inject cannot be used in combination with the @Provides annotation * or with constructor injection. @@ -396,4 +411,19 @@ static class KmsDataKeyCachingOptionalPropertyHolder { @com.google.inject.name.Named("cms.encryption.cache.decrypt.maxAgeInSeconds") int decryptMaxAge = 0; } + + /** + * This 'holder' class allows optional injection of KMS-auth-caching-specific properties that are only necessary when + * KMS-auth caching is enabled. + * + * The 'optional=true' parameter to Guice @Inject cannot be used in combination with the @Provides annotation + * or with constructor injection. + * + * https://github.com/google/guice/wiki/FrequentlyAskedQuestions + */ + static class KmsAuthCachingOptionalPropertyHolder { + @Inject(optional=true) + @com.google.inject.name.Named("cms.iam.token.cache.maxAgeInSeconds") + int maxAge = 10; + } } diff --git a/src/main/java/com/nike/cerberus/service/AuthenticationService.java b/src/main/java/com/nike/cerberus/service/AuthenticationService.java index ebeb25ef1..18573ed30 100644 --- a/src/main/java/com/nike/cerberus/service/AuthenticationService.java +++ b/src/main/java/com/nike/cerberus/service/AuthenticationService.java @@ -27,6 +27,7 @@ import com.amazonaws.services.kms.model.KMSInvalidStateException; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import com.github.benmanes.caffeine.cache.Cache; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -98,6 +99,8 @@ public class AuthenticationService { public static final String MAX_TOKEN_REFRESH_COUNT = "cms.user.token.maxRefreshCount"; public static final String USER_TOKEN_TTL = "cms.user.token.ttl"; public static final String IAM_TOKEN_TTL = "cms.iam.token.ttl"; + public static final String CACHE_ENABLED = "cms.auth.iam.kms.cache.enabled"; + public static final String CACHE = "kmsAuthCache"; public static final String LOOKUP_SELF_POLICY = "lookup-self"; public static final int KMS_SIZE_LIMIT = 4096; @@ -116,6 +119,8 @@ public class AuthenticationService { private final String iamTokenTTL; private final AwsIamRoleService awsIamRoleService; private final int maxTokenRefreshCount; + private final boolean cacheEnabled; + private final Cache cache; @Inject(optional=true) @Named(ADMIN_IAM_ROLES_PROPERTY) @@ -138,8 +143,9 @@ public AuthenticationService(SafeDepositBoxDao safeDepositBoxDao, AuthTokenService authTokenService, @Named(USER_TOKEN_TTL) String userTokenTTL, @Named(IAM_TOKEN_TTL) String iamTokenTTL, - AwsIamRoleService awsIamRoleService) { - + AwsIamRoleService awsIamRoleService, + @Named(CACHE_ENABLED) boolean cacheEnabled, + @Named(CACHE) Cache cache) { this.safeDepositBoxDao = safeDepositBoxDao; this.awsIamRoleDao = awsIamRoleDao; this.authServiceConnector = authConnector; @@ -155,6 +161,8 @@ public AuthenticationService(SafeDepositBoxDao safeDepositBoxDao, this.userTokenTTL = userTokenTTL; this.iamTokenTTL = iamTokenTTL; this.awsIamRoleService = awsIamRoleService; + this.cacheEnabled = cacheEnabled; + this.cache = cache; } /** @@ -228,7 +236,7 @@ public IamRoleAuthResponse authenticate(IamRoleCredentials credentials) { authPrincipalMetadata.put(CerberusPrincipal.METADATA_KEY_AWS_ACCOUNT_ID, awsIamRoleArnParser.getAccountId(iamPrincipalArn)); authPrincipalMetadata.put(CerberusPrincipal.METADATA_KEY_AWS_IAM_ROLE_NAME, awsIamRoleArnParser.getRoleName(iamPrincipalArn)); - return authenticate(iamPrincipalCredentials, authPrincipalMetadata); + return cachingKmsAuthenticate(iamPrincipalCredentials, authPrincipalMetadata); } public IamRoleAuthResponse authenticate(IamPrincipalCredentials credentials) { @@ -237,7 +245,7 @@ public IamRoleAuthResponse authenticate(IamPrincipalCredentials credentials) { final Map authPrincipalMetadata = generateCommonIamPrincipalAuthMetadata(iamPrincipalArn, credentials.getRegion()); authPrincipalMetadata.put(CerberusPrincipal.METADATA_KEY_AWS_IAM_PRINCIPAL_ARN, iamPrincipalArn); - return authenticate(credentials, authPrincipalMetadata); + return cachingKmsAuthenticate(credentials, authPrincipalMetadata); } /** @@ -257,6 +265,14 @@ public AuthTokenResponse stsAuthenticate(final String iamPrincipalArn) { return authResponse; } + private IamRoleAuthResponse cachingKmsAuthenticate(IamPrincipalCredentials credentials, Map authPrincipalMetadata) { + if (cacheEnabled){ + return cache.get(credentials, key -> authenticate(credentials, authPrincipalMetadata)); + } else { + return authenticate(credentials, authPrincipalMetadata); + } + } + private IamRoleAuthResponse authenticate(IamPrincipalCredentials credentials, Map authPrincipalMetadata) { final AwsIamRoleKmsKeyRecord kmsKeyRecord; final AwsIamRoleRecord iamRoleRecord; diff --git a/src/main/java/com/nike/cerberus/service/MetricsService.java b/src/main/java/com/nike/cerberus/service/MetricsService.java index d980bc963..3f22c112f 100644 --- a/src/main/java/com/nike/cerberus/service/MetricsService.java +++ b/src/main/java/com/nike/cerberus/service/MetricsService.java @@ -18,7 +18,9 @@ package com.nike.cerberus.service; import com.codahale.metrics.Counter; +import com.codahale.metrics.Gauge; import com.codahale.metrics.Metric; +import com.nike.cerberus.metrics.CallbackLongGauge; import com.nike.riposte.metrics.codahale.CodahaleMetricsCollector; import com.nike.riposte.metrics.codahale.contrib.SignalFxReporterFactory; import com.signalfx.codahale.metrics.MetricBuilder; @@ -31,6 +33,7 @@ import javax.annotation.Nullable; import javax.inject.Inject; import java.util.Map; +import java.util.function.Supplier; public class MetricsService { @@ -97,6 +100,10 @@ public Counter getOrCreateCounter(String name, Map dimensions) { return getOrCreate(MetricBuilder.COUNTERS, name, dimensions); } + public Gauge getOrCreateLongCallbackGauge(String name, Supplier supplier, Map dimensions) { + return getOrCreate(CallbackLongGauge.Builder.getInstance(supplier), name, dimensions); + } + private M getOrCreate(MetricBuilder builder, String metricName, Map dimensions) { if (metricMetadata == null) { diff --git a/src/main/resources/cms.conf b/src/main/resources/cms.conf index f0d7fd758..92e4a4234 100644 --- a/src/main/resources/cms.conf +++ b/src/main/resources/cms.conf @@ -198,3 +198,4 @@ cms.iam.token.ttl=1h cms.user.groups.caseSensitive=true cms.encryption.cache.enabled=false +cms.auth.iam.kms.cache.enabled=false diff --git a/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java b/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java index 94e18641a..ba93a158d 100644 --- a/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java +++ b/src/test/java/com/nike/cerberus/service/AuthenticationServiceTest.java @@ -130,7 +130,9 @@ public void setup() { authTokenService, "1h", "1h", - awsIamRoleService + awsIamRoleService, + false, + null ); }