Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OPIK-547 Store and retrieve LLM provider api key #845

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions apps/opik-backend/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,6 @@ metadata:

cors:
enabled: ${CORS:-false}

encryption:
key: ${OPIK_ENCRYPTION_KEY:-'GiTHubiLoVeYouAA'}
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.annotation.Nullable;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Pattern;
import lombok.Builder;

import java.time.Instant;
import java.util.List;
import java.util.UUID;

import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
// This annotation is used to specify the strategy to be used for naming of properties for the annotated type. Required so that OpenAPI schema generation uses snake_case
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import lombok.Builder;

import java.time.Instant;
import java.util.UUID;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record ProviderApiKey(
@JsonView( {
View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID id,
@JsonView({View.Public.class, View.Write.class}) @NotBlank String provider,
BorisTkachenko marked this conversation as resolved.
Show resolved Hide resolved
@JsonView({View.Write.class}) @NotBlank String apiKey,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy
) {
@Override
public String toString() {
return "ProviderApiKey{" +
"id=" + id +
", provider='" + provider + '\'' +
", createdAt=" + createdAt +
", createdBy='" + createdBy + '\'' +
", lastUpdatedAt=" + lastUpdatedAt +
", lastUpdatedBy='" + lastUpdatedBy + '\'' +
'}';
}

public static class View {
public static class Write {
}

public static class Public {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.constraints.NotBlank;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
@Getter
public class ProviderApiKeyUpdate {
@ToString.Exclude
@NotBlank
String apiKey;
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ public Response toResponse(InvalidFormatException exception) {
log.info("Deserialization exception: {}", exception.getMessage());
int endIndex = errorMessage.indexOf(": Failed to deserialize");
return Response.status(Response.Status.BAD_REQUEST)
.entity(new ErrorMessage(List.of(endIndex == -1 ? "Unable to process JSON" : errorMessage.substring(0, endIndex))))
.entity(new ErrorMessage(
List.of(endIndex == -1 ? "Unable to process JSON" : errorMessage.substring(0, endIndex))))
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package com.comet.opik.api.resources.v1.priv;

import com.codahale.metrics.annotation.Timed;
import com.comet.opik.api.ProviderApiKey;
import com.comet.opik.api.ProviderApiKeyUpdate;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.domain.ProxyService;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.fasterxml.jackson.annotation.JsonView;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.inject.Inject;
import jakarta.inject.Provider;
import jakarta.validation.Valid;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.PATCH;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.UriInfo;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.UUID;

@Path("/v1/private/proxy")
BorisTkachenko marked this conversation as resolved.
Show resolved Hide resolved
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
@Timed
@Slf4j
@RequiredArgsConstructor(onConstructor_ = @Inject)
@Tag(name = "Proxy", description = "LLM Provider Proxy")
public class ProxyResource {

private final @NonNull ProxyService proxyService;
private final @NonNull Provider<RequestContext> requestContext;

@GET
@Path("/api_key/{id}")
@Operation(operationId = "getProviderApiKeyById", summary = "Get Provider's ApiKey by id", description = "Get Provider's ApiKey by id", responses = {
@ApiResponse(responseCode = "200", description = "ProviderApiKey resource", content = @Content(schema = @Schema(implementation = ProviderApiKey.class)))})
BorisTkachenko marked this conversation as resolved.
Show resolved Hide resolved
@JsonView({ProviderApiKey.View.Public.class})
public Response getById(@PathParam("id") UUID id) {
BorisTkachenko marked this conversation as resolved.
Show resolved Hide resolved

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Getting Provider's ApiKey by id '{}' on workspace_id '{}'", id, workspaceId);

ProviderApiKey providerApiKey = proxyService.get(id);

log.info("Got Provider's ApiKey by id '{}' on workspace_id '{}'", id, workspaceId);

return Response.ok().entity(providerApiKey).build();
}

@POST
@Path("/api_key")
@Operation(operationId = "storeApiKey", summary = "Store Provider's ApiKey", description = "Store Provider's ApiKey", responses = {
@ApiResponse(responseCode = "201", description = "Created", headers = {
@Header(name = "Location", required = true, example = "${basePath}/v1/private/proxy/api_key/{apiKeyId}", schema = @Schema(implementation = String.class))}),
@ApiResponse(responseCode = "401", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "403", description = "Access forbidden", content = @Content(schema = @Schema(implementation = ErrorMessage.class)))
})
public Response saveApiKey(
@RequestBody(content = @Content(schema = @Schema(implementation = ProviderApiKey.class))) @JsonView(ProviderApiKey.View.Write.class) @Valid ProviderApiKey providerApiKey,
@Context UriInfo uriInfo) {
String workspaceId = requestContext.get().getWorkspaceId();
log.info("Save api key for provider '{}', on workspace_id '{}'", providerApiKey.provider(), workspaceId);
var providerApiKeyId = proxyService.saveApiKey(providerApiKey).id();
log.info("Saved api key for provider '{}', on workspace_id '{}'", providerApiKey.provider(), workspaceId);

var uri = uriInfo.getAbsolutePathBuilder().path("/%s".formatted(providerApiKeyId)).build();

return Response.created(uri).build();
}

@PATCH
@Path("/api_key/{id}")
@Operation(operationId = "storeApiKey", summary = "Store Provider's ApiKey", description = "Store Provider's ApiKey", responses = {
@ApiResponse(responseCode = "204", description = "No Content"),
@ApiResponse(responseCode = "401", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "403", description = "Access forbidden", content = @Content(schema = @Schema(implementation = ErrorMessage.class))),
@ApiResponse(responseCode = "404", description = "Not found", content = @Content(schema = @Schema(implementation = ErrorMessage.class)))
})
public Response updateApiKey(@PathParam("id") UUID id,
@RequestBody(content = @Content(schema = @Schema(implementation = ProviderApiKeyUpdate.class))) @Valid ProviderApiKeyUpdate providerApiKeyUpdate) {
String workspaceId = requestContext.get().getWorkspaceId();

log.info("Updating api key for provider with id '{}' on workspaceId '{}'", id, workspaceId);
proxyService.updateApiKey(id, providerApiKeyUpdate);
log.info("Updated api key for provider with id '{}' on workspaceId '{}'", id, workspaceId);

return Response.noContent().build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package com.comet.opik.domain;

import com.comet.opik.api.ProviderApiKey;
import com.comet.opik.infrastructure.db.UUIDArgumentFactory;
import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory;
import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper;
import org.jdbi.v3.sqlobject.customizer.Bind;
import org.jdbi.v3.sqlobject.customizer.BindMethods;
import org.jdbi.v3.sqlobject.statement.SqlQuery;
import org.jdbi.v3.sqlobject.statement.SqlUpdate;

import java.util.Optional;
import java.util.UUID;

@RegisterConstructorMapper(ProviderApiKey.class)
@RegisterArgumentFactory(UUIDArgumentFactory.class)
public interface ProviderApiKeyDAO {
BorisTkachenko marked this conversation as resolved.
Show resolved Hide resolved

@SqlUpdate("INSERT INTO provider_api_key (id, provider, workspace_id, api_key, created_by, last_updated_by) VALUES (:bean.id, :bean.provider, :workspaceId, :bean.apiKey, :bean.createdBy, :bean.lastUpdatedBy)")
void save(@Bind("workspaceId") String workspaceId,
@BindMethods("bean") ProviderApiKey providerApiKey);

@SqlUpdate("UPDATE provider_api_key SET " +
"api_key = :apiKey, " +
"last_updated_by = :lastUpdatedBy " +
"WHERE id = :id AND workspace_id = :workspaceId")
void update(@Bind("id") UUID id,
@Bind("workspaceId") String workspaceId,
@Bind("apiKey") String encryptedApiKey,
@Bind("lastUpdatedBy") String lastUpdatedBy);

@SqlQuery("SELECT * FROM provider_api_key WHERE id = :id AND workspace_id = :workspaceId")
ProviderApiKey findById(@Bind("id") UUID id, @Bind("workspaceId") String workspaceId);

default Optional<ProviderApiKey> fetch(UUID id, String workspaceId) {
return Optional.ofNullable(findById(id, workspaceId));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package com.comet.opik.domain;

import com.comet.opik.api.ProviderApiKey;
import com.comet.opik.api.ProviderApiKeyUpdate;
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.infrastructure.EncryptionService;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.google.inject.ImplementedBy;
import jakarta.inject.Inject;
import jakarta.inject.Provider;
import jakarta.inject.Singleton;
import jakarta.ws.rs.NotFoundException;
import jakarta.ws.rs.core.Response;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.jdbi.v3.core.statement.UnableToExecuteStatementException;
import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate;

import java.sql.SQLIntegrityConstraintViolationException;
import java.util.List;
import java.util.UUID;

import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY;
import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE;

@ImplementedBy(ProxyServiceImpl.class)
public interface ProxyService {
BorisTkachenko marked this conversation as resolved.
Show resolved Hide resolved

ProviderApiKey get(UUID id);
ProviderApiKey saveApiKey(ProviderApiKey providerApiKey);
void updateApiKey(UUID id, ProviderApiKeyUpdate providerApiKeyUpdate);
}

@Slf4j
@Singleton
@RequiredArgsConstructor(onConstructor_ = @Inject)
class ProxyServiceImpl implements ProxyService {

private static final String PROVIDER_API_KEY_ALREADY_EXISTS = "Api key for this provider already exists";
private final @NonNull Provider<RequestContext> requestContext;
BorisTkachenko marked this conversation as resolved.
Show resolved Hide resolved
private final @NonNull IdGenerator idGenerator;
private final @NonNull TransactionTemplate template;
private final @NonNull EncryptionService encryptionService;

@Override
public ProviderApiKey get(UUID id) {
String workspaceId = requestContext.get().getWorkspaceId();

log.info("Getting provider api key with id '{}', workspaceId '{}'", id, workspaceId);

var providerApiKey = template.inTransaction(READ_ONLY, handle -> {

var repository = handle.attach(ProviderApiKeyDAO.class);

return repository.fetch(id, workspaceId).orElseThrow(this::createNotFoundError);
});
log.info("Got provider api key with id '{}', workspaceId '{}'", id, workspaceId);

return providerApiKey.toBuilder()
.apiKey(encryptionService.decrypt(providerApiKey.apiKey()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for internal usage. Api key will not be serialized, and should never be exposed outside of application. So there is no sense in additional Serializer configurations.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to the other thread. Ideally, api key deserialization should be triggered on the last minute, just right before inserting it as auth header in the client call towards the provider.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This field is excluded from serialization when returned to UI as I mentioned above. I removed decryption, since it will be used in some other service which is not implemented yet.

.build();
}

@Override
public ProviderApiKey saveApiKey(@NonNull ProviderApiKey providerApiKey) {
UUID apiKeyId = idGenerator.generateId();
String userName = requestContext.get().getUserName();
String workspaceId = requestContext.get().getWorkspaceId();

var newProviderApiKey = providerApiKey.toBuilder()
.id(apiKeyId)
.apiKey(encryptionService.encrypt(providerApiKey.apiKey()))
andrescrz marked this conversation as resolved.
Show resolved Hide resolved
.createdBy(userName)
.lastUpdatedBy(userName)
.build();

try {
template.inTransaction(WRITE, handle -> {

var repository = handle.attach(ProviderApiKeyDAO.class);
repository.save(workspaceId, newProviderApiKey);

return newProviderApiKey;
});

return get(apiKeyId);
} catch (UnableToExecuteStatementException e) {
if (e.getCause() instanceof SQLIntegrityConstraintViolationException) {
throw newConflict();
} else {
throw e;
}
}
}

@Override
public void updateApiKey(@NonNull UUID id, @NonNull ProviderApiKeyUpdate providerApiKeyUpdate) {
String userName = requestContext.get().getUserName();
String workspaceId = requestContext.get().getWorkspaceId();
String encryptedApiKey = encryptionService.encrypt(providerApiKeyUpdate.getApiKey());

template.inTransaction(WRITE, handle -> {

var repository = handle.attach(ProviderApiKeyDAO.class);

ProviderApiKey providerApiKey = repository.fetch(id, workspaceId)
.orElseThrow(this::createNotFoundError);

repository.update(providerApiKey.id(),
workspaceId,
encryptedApiKey,
userName);

return null;
});
}

private EntityAlreadyExistsException newConflict() {
log.info(PROVIDER_API_KEY_ALREADY_EXISTS);
return new EntityAlreadyExistsException(new ErrorMessage(List.of(PROVIDER_API_KEY_ALREADY_EXISTS)));
}

private NotFoundException createNotFoundError() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: inherited from similar code, but this method name should be renamed in order to be reusable e.g: throwApiKeyNotFoundException or similar.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left as is, since this function is used as a Supplier, and it actually just creates an exception instance and doesn't throw anything.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'd call it newNotFoundException then, but anyway, very very picky and minor.

String message = "Provider api key not found";
log.info(message);
return new NotFoundException(message,
Response.status(Response.Status.NOT_FOUND).entity(new ErrorMessage(List.of(message))).build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ private List<Span> bindSpanToProjectAndId(SpanBatch batch, List<Project> project
Project project = projectPerName.get(projectName);

if (project == null) {
log.warn("Project not found for span project '{}' and default '{}'", span.projectName(), projectName);
log.warn("Project not found for span project '{}' and default '{}'", span.projectName(),
projectName);
throw new IllegalStateException("Project not found: %s".formatted(span.projectName()));
}

Expand Down
Loading
Loading