Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NA: Add missing cache for sdk auth #184

Merged
merged 2 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions apps/opik-backend/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ redis:

authentication:
enabled: ${AUTH_ENABLED:-false}
apiKeyResolutionCacheTTLInSec: ${AUTH_API_KEY_RESOLUTION_CACHE_TTL_IN_SEC:-5} #0 means no cache
sdk:
url: ${AUTH_SDK_URL:-''}
ui:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ private Mono<Long> insert(UUID datasetId, List<DatasetItem> items) {
List<List<DatasetItem>> batches = Lists.partition(items, bulkConfig.getSize());

return Flux.fromIterable(batches)
.flatMapSequential(batch -> asyncTemplate.nonTransaction(connection -> mapAndInsert(datasetId, batch, connection)))
.flatMapSequential(
batch -> asyncTemplate.nonTransaction(connection -> mapAndInsert(datasetId, batch, connection)))
.reduce(0L, Long::sum);
}

Expand Down Expand Up @@ -432,7 +433,7 @@ private Mono<Long> mapAndInsert(UUID datasetId, List<DatasetItem> items, Connect
statement.bind("input" + i, getOrDefault(item.input()));
statement.bind("expectedOutput" + i, getOrDefault(item.expectedOutput()));
statement.bind("metadata" + i, getOrDefault(item.metadata()));
statement.bind("createdBy" + i,userName);
statement.bind("createdBy" + i, userName);
statement.bind("lastUpdatedBy" + i, userName);
i++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ INSERT INTO experiment_items (
new.last_updated_by
FROM (
<items:{item |
SELECT
SELECT
:id<item.index> AS id,
:experiment_id<item.index> AS experiment_id,
:dataset_item_id<item.index> AS dataset_item_id,
Expand Down Expand Up @@ -207,7 +207,6 @@ private Mono<Long> insert(Collection<ExperimentItem> experimentItems, Connection
});
}


private Publisher<ExperimentItem> mapToExperimentItem(Result result) {
return result.map((row, rowMetadata) -> ExperimentItem.builder()
.id(row.get("id", UUID.class))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ public record UrlConfig(@Valid @JsonProperty @NotNull String url) {
@JsonProperty
private boolean enabled;

@Valid
@JsonProperty
private int apiKeyResolutionCacheTTLInSec;

@Valid
@JsonProperty
private UrlConfig ui;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.comet.opik.infrastructure;


import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotNull;
Expand All @@ -11,6 +10,5 @@ public class BulkOperationsConfig {

@Valid
@JsonProperty
@NotNull
private int size;
@NotNull private int size;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: int primitive can't be null.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will remove in the next PR

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package com.comet.opik.infrastructure.auth;

import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.redisson.api.RListReactive;
import org.redisson.api.RedissonReactiveClient;

import java.time.Duration;
import java.util.List;
import java.util.Optional;

@Slf4j
@RequiredArgsConstructor
class AuthCredentialsCacheService implements CacheService {

public static final String KEY_FORMAT = "auth-%s-%s";
private final RedissonReactiveClient redissonClient;
private final int ttlInSeconds;

public Optional<AuthCredentials> resolveApiKeyUserAndWorkspaceIdFromCache(@NonNull String apiKey,
@NonNull String workspaceName) {
String key = KEY_FORMAT.formatted(apiKey, workspaceName);

RListReactive<String> bucket = redissonClient.getList(key);

return bucket
.readAll()
.blockOptional()
.filter(pair -> pair.size() == 2)
.map(pair -> new AuthCredentials(pair.getFirst(), pair.getLast()));
}

public void cache(@NonNull String apiKey, @NonNull String workspaceName, @NonNull String userName,
@NonNull String workspaceId) {
String key = KEY_FORMAT.formatted(apiKey, workspaceName);
redissonClient.getList(key).addAll(List.of(userName, workspaceId)).block();
redissonClient.getList(key).expire(Duration.ofSeconds(ttlInSeconds)).block();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.comet.opik.infrastructure.AuthenticationConfig;
import com.comet.opik.infrastructure.OpikConfiguration;
import com.comet.opik.infrastructure.redis.LockService;
import com.google.common.base.Preconditions;
import com.google.inject.Provides;
import jakarta.inject.Provider;
Expand All @@ -10,6 +11,7 @@
import jakarta.ws.rs.client.ClientBuilder;
import lombok.NonNull;
import org.apache.commons.lang3.StringUtils;
import org.redisson.api.RedissonReactiveClient;
import ru.vyarus.dropwizard.guice.module.support.DropwizardAwareModule;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

Expand All @@ -21,7 +23,9 @@ public class AuthModule extends DropwizardAwareModule<OpikConfiguration> {
@Singleton
public AuthService authService(
@Config("authentication") AuthenticationConfig config,
@NonNull Provider<RequestContext> requestContext) {
@NonNull Provider<RequestContext> requestContext,
@NonNull RedissonReactiveClient redissonClient,
@NonNull LockService lockService) {

if (!config.isEnabled()) {
return new AuthServiceImpl(requestContext);
Expand All @@ -37,7 +41,12 @@ public AuthService authService(
Preconditions.checkArgument(StringUtils.isNotBlank(config.getSdk().url()),
"The property authentication.sdk.url must not be blank when authentication is enabled");

return new RemoteAuthService(client(), config.getSdk(), config.getUi(), requestContext);
var cacheService = config.getApiKeyResolutionCacheTTLInSec() > 0
? new AuthCredentialsCacheService(redissonClient, config.getApiKeyResolutionCacheTTLInSec())
: new NoopCacheService();

return new RemoteAuthService(client(), config.getSdk(), config.getUi(), requestContext, cacheService,
lockService);
}

public Client client() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.comet.opik.infrastructure.auth;

import java.util.Optional;

interface CacheService {

record AuthCredentials(String userName, String workspaceId) {
}

void cache(String apiKey, String workspaceName, String userName, String workspaceId);
Optional<AuthCredentials> resolveApiKeyUserAndWorkspaceIdFromCache(String apiKey, String workspaceName);
}

class NoopCacheService implements CacheService {

@Override
public void cache(String apiKey, String workspaceName, String userName, String workspaceId) {
// no-op
}

@Override
public Optional<AuthCredentialsCacheService.AuthCredentials> resolveApiKeyUserAndWorkspaceIdFromCache(
String apiKey, String workspaceName) {
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.comet.opik.infrastructure.auth;

import com.comet.opik.domain.ProjectService;
import com.comet.opik.infrastructure.redis.LockService;
import jakarta.inject.Provider;
import jakarta.ws.rs.ClientErrorException;
import jakarta.ws.rs.client.Client;
Expand All @@ -13,26 +14,37 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.net.URI;
import java.util.Optional;

import static com.comet.opik.infrastructure.AuthenticationConfig.UrlConfig;
import static com.comet.opik.infrastructure.auth.AuthCredentialsCacheService.AuthCredentials;
import static com.comet.opik.infrastructure.redis.LockService.Lock;

@RequiredArgsConstructor
@Slf4j
class RemoteAuthService implements AuthService {

public static final String NOT_ALLOWED_TO_ACCESS_WORKSPACE = "User not allowed to access workspace";
private final @NonNull Client client;
private final @NonNull UrlConfig apiKeyAuthUrl;
private final @NonNull UrlConfig uiAuthUrl;
private final @NonNull Provider<RequestContext> requestContext;
private final @NonNull CacheService cacheService;
private final @NonNull LockService lockService;

record AuthRequest(String workspaceName) {
}

record AuthResponse(String user, String workspaceId) {
}

record ValidatedAuthCredentials(boolean shouldCache, String userName, String workspaceId) {
}

@Override
public void authenticate(HttpHeaders headers, Cookie sessionToken) {

Expand Down Expand Up @@ -66,24 +78,61 @@ private void authenticateUsingSessionToken(Cookie sessionToken, String workspace
.cookie(sessionToken)
.post(Entity.json(new AuthRequest(workspaceName)))) {

verifyResponse(response);
AuthResponse credentials = verifyResponse(response);

setCredentialIntoContext(credentials.user(), credentials.workspaceId());
}
}

private void authenticateUsingApiKey(HttpHeaders headers, String workspaceName) {
try (var response = client.target(URI.create(apiKeyAuthUrl.url()))
.request()
.accept(MediaType.APPLICATION_JSON)
.header(jakarta.ws.rs.core.HttpHeaders.AUTHORIZATION,
Optional.ofNullable(headers.getHeaderString(jakarta.ws.rs.core.HttpHeaders.AUTHORIZATION))
.orElse(""))
.post(Entity.json(new AuthRequest(workspaceName)))) {

verifyResponse(response);
String apiKey = Optional.ofNullable(headers.getHeaderString(HttpHeaders.AUTHORIZATION))
.orElse("");

if (apiKey.isBlank()) {
log.info("API key not found in headers");
throw new ClientErrorException(NOT_ALLOWED_TO_ACCESS_WORKSPACE, Response.Status.UNAUTHORIZED);
}

var lock = new Lock(apiKey, workspaceName);

ValidatedAuthCredentials credentials = lockService.executeWithLock(
lock,
Mono.fromCallable(() -> validateApiKeyAndGetCredentials(workspaceName, apiKey))
.subscribeOn(Schedulers.boundedElastic()))
.block();

if (credentials.shouldCache()) {
log.debug("Caching user and workspace id for API key");
cacheService.cache(apiKey, workspaceName, credentials.userName(), credentials.workspaceId());
}

setCredentialIntoContext(credentials.userName(), credentials.workspaceId());
}

private ValidatedAuthCredentials validateApiKeyAndGetCredentials(String workspaceName, String apiKey) {
Optional<AuthCredentials> credentials = cacheService.resolveApiKeyUserAndWorkspaceIdFromCache(apiKey,
workspaceName);

if (credentials.isEmpty()) {
log.debug("User and workspace id not found in cache for API key");

try (var response = client.target(URI.create(apiKeyAuthUrl.url()))
.request()
.accept(MediaType.APPLICATION_JSON)
.header(HttpHeaders.AUTHORIZATION,
apiKey)
.post(Entity.json(new AuthRequest(workspaceName)))) {

AuthResponse authResponse = verifyResponse(response);
return new ValidatedAuthCredentials(true, authResponse.user(), authResponse.workspaceId());
}
} else {
return new ValidatedAuthCredentials(false, credentials.get().userName(), credentials.get().workspaceId());
}
}

private void verifyResponse(Response response) {
private AuthResponse verifyResponse(Response response) {
if (response.getStatusInfo().getFamily() == Response.Status.Family.SUCCESSFUL) {
var authResponse = response.readEntity(AuthResponse.class);

Expand All @@ -92,12 +141,9 @@ private void verifyResponse(Response response) {
throw new ClientErrorException(Response.Status.UNAUTHORIZED);
}

requestContext.get().setUserName(authResponse.user());
requestContext.get().setWorkspaceId(authResponse.workspaceId());
return;

return authResponse;
} else if (response.getStatus() == Response.Status.UNAUTHORIZED.getStatusCode()) {
throw new ClientErrorException("User not allowed to access workspace",
throw new ClientErrorException(NOT_ALLOWED_TO_ACCESS_WORKSPACE,
Response.Status.UNAUTHORIZED);
} else if (response.getStatus() == Response.Status.FORBIDDEN.getStatusCode()) {
throw new ClientErrorException("User has bot permission to the workspace", Response.Status.FORBIDDEN);
Expand All @@ -110,4 +156,9 @@ private void verifyResponse(Response response) {
throw new ClientErrorException(Response.Status.INTERNAL_SERVER_ERROR);
}

private void setCredentialIntoContext(String userName, String workspaceId) {
requestContext.get().setUserName(userName);
requestContext.get().setWorkspaceId(workspaceId);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,18 @@

public interface LockService {

record Lock(UUID id, String name) {
record Lock(String key) {

private static final String KEY_FORMAT = "%s-%s";

public Lock(UUID id, String name) {
this(KEY_FORMAT.formatted(id, name));
}

public Lock(String id, String name) {
this(KEY_FORMAT.formatted(id, name));
}

}

<T> Mono<T> executeWithLock(Lock lock, Mono<T> action);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public <T> Mono<T> executeWithLock(Lock lock, Mono<T> action) {

RPermitExpirableSemaphoreReactive semaphore = redisClient.getPermitExpirableSemaphore(
CommonOptions
.name("%s-%s".formatted(lock.id(), lock.name()))
.name(lock.key())
.timeout(Duration.ofMillis(distributedLockConfig.getLockTimeoutMS()))
.retryInterval(Duration.ofMillis(10))
.retryAttempts(distributedLockConfig.getLockTimeoutMS() / 10));
Expand Down Expand Up @@ -56,7 +56,7 @@ private <T> Mono<T> runAction(Lock lock, Mono<T> action, String locked) {
public <T> Flux<T> executeWithLock(Lock lock, Flux<T> stream) {
RPermitExpirableSemaphoreReactive semaphore = redisClient.getPermitExpirableSemaphore(
CommonOptions
.name("%s-%s".formatted(lock.id(), lock.name()))
.name(lock.key())
.timeout(Duration.ofMillis(distributedLockConfig.getLockTimeoutMS()))
.retryInterval(Duration.ofMillis(10))
.retryAttempts(distributedLockConfig.getLockTimeoutMS() / 10));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,19 @@ public static TestDropwizardAppExtension newTestDropwizardAppExtension(
}

public static TestDropwizardAppExtension newTestDropwizardAppExtension(
String jdbcUrl, DatabaseAnalyticsFactory databaseAnalyticsFactory, WireMockRuntimeInfo runtimeInfo,
String jdbcUrl,
DatabaseAnalyticsFactory databaseAnalyticsFactory,
WireMockRuntimeInfo runtimeInfo,
String redisUrl) {
return newTestDropwizardAppExtension(jdbcUrl, databaseAnalyticsFactory, runtimeInfo, redisUrl, null);
}

public static TestDropwizardAppExtension newTestDropwizardAppExtension(
String jdbcUrl,
DatabaseAnalyticsFactory databaseAnalyticsFactory,
WireMockRuntimeInfo runtimeInfo,
String redisUrl,
Integer cacheTtlInSeconds) {

var list = new ArrayList<String>();
list.add("database.url: " + jdbcUrl);
Expand All @@ -43,6 +54,10 @@ public static TestDropwizardAppExtension newTestDropwizardAppExtension(
list.add("authentication.enabled: true");
list.add("authentication.sdk.url: " + "%s/opik/auth".formatted(runtimeInfo.getHttpsBaseUrl()));
list.add("authentication.ui.url: " + "%s/opik/auth-session".formatted(runtimeInfo.getHttpsBaseUrl()));

if (cacheTtlInSeconds != null) {
list.add("authentication.apiKeyResolutionCacheTTLInSec: " + cacheTtlInSeconds);
}
}

GuiceyConfigurationHook hook = injector -> {
Expand Down
Loading