diff --git a/apps/opik-backend/config.yml b/apps/opik-backend/config.yml index 61d2fb8ab2..0bbd465603 100644 --- a/apps/opik-backend/config.yml +++ b/apps/opik-backend/config.yml @@ -88,3 +88,6 @@ metadata: cors: enabled: ${CORS:-false} + +encryption: + key: ${OPIK_ENCRYPTION_KEY:-'GiTHubiLoVeYouAA'} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java index 396610671e..1ad8300518 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java @@ -2,6 +2,7 @@ import com.comet.opik.api.error.JsonInvalidFormatExceptionMapper; import com.comet.opik.infrastructure.ConfigurationModule; +import com.comet.opik.infrastructure.EncryptionUtilsModule; import com.comet.opik.infrastructure.OpikConfiguration; import com.comet.opik.infrastructure.auth.AuthModule; import com.comet.opik.infrastructure.bi.BiModule; @@ -69,7 +70,7 @@ public void initialize(Bootstrap bootstrap) { .withPlugins(new SqlObjectPlugin(), new Jackson2Plugin())) .modules(new DatabaseAnalyticsModule(), new IdGeneratorModule(), new AuthModule(), new RedisModule(), new RateLimitModule(), new NameGeneratorModule(), new HttpModule(), new EventModule(), - new ConfigurationModule(), new BiModule()) + new ConfigurationModule(), new BiModule(), new EncryptionUtilsModule()) .installers(JobGuiceyInstaller.class) .listen(new OpikGuiceyLifecycleEventListener()) .enableAutoConfig() diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java b/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java new file mode 100644 index 0000000000..234433e39b --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java @@ -0,0 +1,13 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +@Getter +@RequiredArgsConstructor +public enum LlmProvider { + + @JsonProperty("openai") + OPEN_AI; +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java index 4cb407a237..5e8b13133b 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java @@ -7,15 +7,12 @@ import io.swagger.v3.oas.annotations.media.Schema; import jakarta.annotation.Nullable; import jakarta.validation.constraints.NotBlank; -import jakarta.validation.constraints.Pattern; import lombok.Builder; import java.time.Instant; import java.util.List; import java.util.UUID; -import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK; - @Builder(toBuilder = true) @JsonIgnoreProperties(ignoreUnknown = true) // This annotation is used to specify the strategy to be used for naming of properties for the annotated type. Required so that OpenAPI schema generation uses snake_case diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKey.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKey.java new file mode 100644 index 0000000000..187ea94d25 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKey.java @@ -0,0 +1,49 @@ +package com.comet.opik.api; + +import com.comet.opik.utils.ProviderApiKeyDeserializer; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotBlank; +import lombok.Builder; +import lombok.NonNull; + +import java.time.Instant; +import java.util.UUID; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record ProviderApiKey( + @JsonView( { + View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID id, + @JsonView({View.Public.class, View.Write.class}) @NonNull LlmProvider provider, + @JsonView({ + View.Write.class}) @NotBlank @JsonDeserialize(using = ProviderApiKeyDeserializer.class) String apiKey, + @JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, + @JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, + @JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){ + @Override + public String toString() { + return "ProviderApiKey{" + + "id=" + id + + ", provider='" + provider + '\'' + + ", createdAt=" + createdAt + + ", createdBy='" + createdBy + '\'' + + ", lastUpdatedAt=" + lastUpdatedAt + + ", lastUpdatedBy='" + lastUpdatedBy + '\'' + + '}'; + } + + public static class View { + public static class Write { + } + + public static class Public { + } + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKeyUpdate.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKeyUpdate.java new file mode 100644 index 0000000000..2591e4743a --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKeyUpdate.java @@ -0,0 +1,22 @@ +package com.comet.opik.api; + +import com.comet.opik.utils.ProviderApiKeyDeserializer; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.NotBlank; +import lombok.Builder; +import lombok.Getter; +import lombok.ToString; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +@Getter +public class ProviderApiKeyUpdate { + @ToString.Exclude + @NotBlank + @JsonDeserialize(using = ProviderApiKeyDeserializer.class) + String apiKey; +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/error/JsonInvalidFormatExceptionMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/api/error/JsonInvalidFormatExceptionMapper.java index f60bf5e36c..64c556c207 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/error/JsonInvalidFormatExceptionMapper.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/error/JsonInvalidFormatExceptionMapper.java @@ -16,7 +16,8 @@ public Response toResponse(InvalidFormatException exception) { log.info("Deserialization exception: {}", exception.getMessage()); int endIndex = errorMessage.indexOf(": Failed to deserialize"); return Response.status(Response.Status.BAD_REQUEST) - .entity(new ErrorMessage(List.of(endIndex == -1 ? "Unable to process JSON" : errorMessage.substring(0, endIndex)))) + .entity(new ErrorMessage( + List.of(endIndex == -1 ? "Unable to process JSON" : errorMessage.substring(0, endIndex)))) .build(); } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/LlmProviderApiKeyResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/LlmProviderApiKeyResource.java new file mode 100644 index 0000000000..0f88f18481 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/LlmProviderApiKeyResource.java @@ -0,0 +1,108 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.codahale.metrics.annotation.Timed; +import com.comet.opik.api.ProviderApiKey; +import com.comet.opik.api.ProviderApiKeyUpdate; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.domain.LlmProviderApiKeyService; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.fasterxml.jackson.annotation.JsonView; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.headers.Header; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.parameters.RequestBody; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.validation.Valid; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.PATCH; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +import java.util.UUID; + +@Path("/v1/private/llm-provider-key") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +@Timed +@Slf4j +@RequiredArgsConstructor(onConstructor_ = @Inject) +@Tag(name = "LlmProviderKey", description = "LLM Provider Key") +public class LlmProviderApiKeyResource { + + private final @NonNull LlmProviderApiKeyService llmProviderApiKeyService; + private final @NonNull Provider requestContext; + + @GET + @Path("{id}") + @Operation(operationId = "getLlmProviderApiKeyById", summary = "Get LLM Provider's ApiKey by id", description = "Get LLM Provider's ApiKey by id", responses = { + @ApiResponse(responseCode = "200", description = "ProviderApiKey resource", content = @Content(schema = @Schema(implementation = ProviderApiKey.class))), + @ApiResponse(responseCode = "404", description = "Not found", content = @Content(schema = @Schema(implementation = ErrorMessage.class)))}) + @JsonView({ProviderApiKey.View.Public.class}) + public Response getById(@PathParam("id") UUID id) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting Provider's ApiKey by id '{}' on workspace_id '{}'", id, workspaceId); + + ProviderApiKey providerApiKey = llmProviderApiKeyService.get(id, workspaceId); + + log.info("Got Provider's ApiKey by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.ok().entity(providerApiKey).build(); + } + + @POST + @Operation(operationId = "storeLlmProviderApiKey", summary = "Store LLM Provider's ApiKey", description = "Store LLM Provider's ApiKey", responses = { + @ApiResponse(responseCode = "201", description = "Created", headers = { + @Header(name = "Location", required = true, example = "${basePath}/v1/private/proxy/api_key/{apiKeyId}", schema = @Schema(implementation = String.class))}), + @ApiResponse(responseCode = "401", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "403", description = "Access forbidden", content = @Content(schema = @Schema(implementation = ErrorMessage.class))) + }) + public Response saveApiKey( + @RequestBody(content = @Content(schema = @Schema(implementation = ProviderApiKey.class))) @JsonView(ProviderApiKey.View.Write.class) @Valid ProviderApiKey providerApiKey, + @Context UriInfo uriInfo) { + String workspaceId = requestContext.get().getWorkspaceId(); + String userName = requestContext.get().getUserName(); + log.info("Save api key for provider '{}', on workspace_id '{}'", providerApiKey.provider(), workspaceId); + var providerApiKeyId = llmProviderApiKeyService.saveApiKey(providerApiKey, userName, workspaceId).id(); + log.info("Saved api key for provider '{}', on workspace_id '{}'", providerApiKey.provider(), workspaceId); + + var uri = uriInfo.getAbsolutePathBuilder().path("/%s".formatted(providerApiKeyId)).build(); + + return Response.created(uri).build(); + } + + @PATCH + @Path("{id}") + @Operation(operationId = "updateLlmProviderApiKey", summary = "Update LLM Provider's ApiKey", description = "Update LLM Provider's ApiKey", responses = { + @ApiResponse(responseCode = "204", description = "No Content"), + @ApiResponse(responseCode = "401", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "403", description = "Access forbidden", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "404", description = "Not found", content = @Content(schema = @Schema(implementation = ErrorMessage.class))) + }) + public Response updateApiKey(@PathParam("id") UUID id, + @RequestBody(content = @Content(schema = @Schema(implementation = ProviderApiKeyUpdate.class))) @Valid ProviderApiKeyUpdate providerApiKeyUpdate) { + String workspaceId = requestContext.get().getWorkspaceId(); + String userName = requestContext.get().getUserName(); + + log.info("Updating api key for provider with id '{}' on workspaceId '{}'", id, workspaceId); + llmProviderApiKeyService.updateApiKey(id, providerApiKeyUpdate, userName, workspaceId); + log.info("Updated api key for provider with id '{}' on workspaceId '{}'", id, workspaceId); + + return Response.noContent().build(); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/LlmProviderApiKeyDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/LlmProviderApiKeyDAO.java new file mode 100644 index 0000000000..eb03d0967a --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/LlmProviderApiKeyDAO.java @@ -0,0 +1,38 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.ProviderApiKey; +import com.comet.opik.infrastructure.db.UUIDArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.customizer.BindMethods; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; + +import java.util.Optional; +import java.util.UUID; + +@RegisterConstructorMapper(ProviderApiKey.class) +@RegisterArgumentFactory(UUIDArgumentFactory.class) +public interface LlmProviderApiKeyDAO { + + @SqlUpdate("INSERT INTO llm_provider_api_key (id, provider, workspace_id, api_key, created_by, last_updated_by) VALUES (:bean.id, :bean.provider, :workspaceId, :bean.apiKey, :bean.createdBy, :bean.lastUpdatedBy)") + void save(@Bind("workspaceId") String workspaceId, + @BindMethods("bean") ProviderApiKey providerApiKey); + + @SqlUpdate("UPDATE llm_provider_api_key SET " + + "api_key = :apiKey, " + + "last_updated_by = :lastUpdatedBy " + + "WHERE id = :id AND workspace_id = :workspaceId") + void update(@Bind("id") UUID id, + @Bind("workspaceId") String workspaceId, + @Bind("apiKey") String encryptedApiKey, + @Bind("lastUpdatedBy") String lastUpdatedBy); + + @SqlQuery("SELECT * FROM llm_provider_api_key WHERE id = :id AND workspace_id = :workspaceId") + ProviderApiKey findById(@Bind("id") UUID id, @Bind("workspaceId") String workspaceId); + + default Optional fetch(UUID id, String workspaceId) { + return Optional.ofNullable(findById(id, workspaceId)); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/LlmProviderApiKeyService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/LlmProviderApiKeyService.java new file mode 100644 index 0000000000..9996c49122 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/LlmProviderApiKeyService.java @@ -0,0 +1,118 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.ProviderApiKey; +import com.comet.opik.api.ProviderApiKeyUpdate; +import com.comet.opik.api.error.EntityAlreadyExistsException; +import com.comet.opik.api.error.ErrorMessage; +import com.google.inject.ImplementedBy; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; +import jakarta.ws.rs.NotFoundException; +import jakarta.ws.rs.core.Response; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.jdbi.v3.core.statement.UnableToExecuteStatementException; +import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; + +import java.sql.SQLIntegrityConstraintViolationException; +import java.util.List; +import java.util.UUID; + +import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY; +import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; + +@ImplementedBy(LlmProviderApiKeyServiceImpl.class) +public interface LlmProviderApiKeyService { + + ProviderApiKey get(UUID id, String workspaceId); + ProviderApiKey saveApiKey(ProviderApiKey providerApiKey, String userName, String workspaceId); + void updateApiKey(UUID id, ProviderApiKeyUpdate providerApiKeyUpdate, String userName, String workspaceId); +} + +@Slf4j +@Singleton +@RequiredArgsConstructor(onConstructor_ = @Inject) +class LlmProviderApiKeyServiceImpl implements LlmProviderApiKeyService { + + private static final String PROVIDER_API_KEY_ALREADY_EXISTS = "Api key for this provider already exists"; + private final @NonNull IdGenerator idGenerator; + private final @NonNull TransactionTemplate template; + + @Override + public ProviderApiKey get(UUID id, String workspaceId) { + log.info("Getting provider api key with id '{}', workspaceId '{}'", id, workspaceId); + + ProviderApiKey providerApiKey = template.inTransaction(READ_ONLY, handle -> { + + var repository = handle.attach(LlmProviderApiKeyDAO.class); + + return repository.fetch(id, workspaceId).orElseThrow(this::createNotFoundError); + }); + log.info("Got provider api key with id '{}', workspaceId '{}'", id, workspaceId); + + return providerApiKey.toBuilder() + .build(); + } + + @Override + public ProviderApiKey saveApiKey(@NonNull ProviderApiKey providerApiKey, String userName, String workspaceId) { + UUID apiKeyId = idGenerator.generateId(); + + var newProviderApiKey = providerApiKey.toBuilder() + .id(apiKeyId) + .createdBy(userName) + .lastUpdatedBy(userName) + .build(); + + try { + template.inTransaction(WRITE, handle -> { + + var repository = handle.attach(LlmProviderApiKeyDAO.class); + repository.save(workspaceId, newProviderApiKey); + + return newProviderApiKey; + }); + + return get(apiKeyId, workspaceId); + } catch (UnableToExecuteStatementException e) { + if (e.getCause() instanceof SQLIntegrityConstraintViolationException) { + throw newConflict(); + } else { + throw e; + } + } + } + + @Override + public void updateApiKey(@NonNull UUID id, @NonNull ProviderApiKeyUpdate providerApiKeyUpdate, String userName, + String workspaceId) { + + template.inTransaction(WRITE, handle -> { + + var repository = handle.attach(LlmProviderApiKeyDAO.class); + + ProviderApiKey providerApiKey = repository.fetch(id, workspaceId) + .orElseThrow(this::createNotFoundError); + + repository.update(providerApiKey.id(), + workspaceId, + providerApiKeyUpdate.getApiKey(), + userName); + + return null; + }); + } + + private EntityAlreadyExistsException newConflict() { + log.info(PROVIDER_API_KEY_ALREADY_EXISTS); + return new EntityAlreadyExistsException(new ErrorMessage(List.of(PROVIDER_API_KEY_ALREADY_EXISTS))); + } + + private NotFoundException createNotFoundError() { + String message = "Provider api key not found"; + log.info(message); + return new NotFoundException(message, + Response.status(Response.Status.NOT_FOUND).entity(new ErrorMessage(List.of(message))).build()); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java index 9809125274..d05356f20f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java @@ -292,7 +292,8 @@ private List bindSpanToProjectAndId(SpanBatch batch, List project Project project = projectPerName.get(projectName); if (project == null) { - log.warn("Project not found for span project '{}' and default '{}'", span.projectName(), projectName); + log.warn("Project not found for span project '{}' and default '{}'", span.projectName(), + projectName); throw new IllegalStateException("Project not found: %s".formatted(span.projectName())); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionConfig.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionConfig.java new file mode 100644 index 0000000000..d7aebfb016 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionConfig.java @@ -0,0 +1,16 @@ +package com.comet.opik.infrastructure; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotNull; +import lombok.Data; +import lombok.ToString; + +@Data +public class EncryptionConfig { + + @Valid + @JsonProperty + @NotNull @ToString.Exclude + private String key; +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionUtils.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionUtils.java new file mode 100644 index 0000000000..839ad16127 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionUtils.java @@ -0,0 +1,66 @@ +package com.comet.opik.infrastructure; + +import jakarta.inject.Inject; + +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.spec.SecretKeySpec; + +import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; + +public class EncryptionUtils { + + private static void init() { + if (key != null) return; + synchronized (EncryptionUtils.class) { + if (key == null) { + byte[] keyBytes = config.getEncryption().getKey().getBytes(StandardCharsets.UTF_8); + key = new SecretKeySpec(keyBytes, ALGO); + } + } + } + + private static final String ALGO = "AES"; + private static final Base64.Encoder mimeEncoder = Base64.getMimeEncoder(); + private static final Base64.Decoder mimeDecoder = Base64.getMimeDecoder(); + private static OpikConfiguration config; + private static Key key; + + @Inject + static void setConfig(OpikConfiguration config) { + EncryptionUtils.config = config; + } + + public static String encrypt(String data) { + init(); + try { + Cipher c = Cipher.getInstance(ALGO); + c.init(Cipher.ENCRYPT_MODE, key); + byte[] encVal = c.doFinal(data.getBytes()); + return mimeEncoder.encodeToString(encVal); + } catch (NoSuchPaddingException | NoSuchAlgorithmException | InvalidKeyException | IllegalBlockSizeException + | BadPaddingException ex) { + throw new SecurityException("Failed to encrypt. " + ex.getMessage(), ex); + } + } + + public static String decrypt(String encryptedData) { + init(); + try { + Cipher c = Cipher.getInstance(ALGO); + c.init(Cipher.DECRYPT_MODE, key); + byte[] decordedValue = mimeDecoder.decode(encryptedData); + byte[] decValue = c.doFinal(decordedValue); + return new String(decValue); + } catch (BadPaddingException | NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException + | IllegalBlockSizeException ex) { + throw new SecurityException("Failed to decrypt. " + ex.getMessage(), ex); + } + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionUtilsModule.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionUtilsModule.java new file mode 100644 index 0000000000..e1861f031a --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionUtilsModule.java @@ -0,0 +1,11 @@ +package com.comet.opik.infrastructure; + +import ru.vyarus.dropwizard.guice.module.support.DropwizardAwareModule; + +public class EncryptionUtilsModule extends DropwizardAwareModule { + + @Override + protected void configure() { + requestStaticInjection(EncryptionUtils.class); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java index 0c3452fb3c..b15b3fddb2 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java @@ -6,6 +6,7 @@ import jakarta.validation.Valid; import jakarta.validation.constraints.NotNull; import lombok.Getter; +import lombok.ToString; @Getter public class OpikConfiguration extends JobConfiguration { @@ -53,4 +54,9 @@ public class OpikConfiguration extends JobConfiguration { @Valid @NotNull @JsonProperty private BatchOperationsConfig batchOperations = new BatchOperationsConfig(); + + @Valid + @NotNull @JsonProperty + @ToString.Exclude + private EncryptionConfig encryption = new EncryptionConfig(); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/utils/ProviderApiKeyDeserializer.java b/apps/opik-backend/src/main/java/com/comet/opik/utils/ProviderApiKeyDeserializer.java new file mode 100644 index 0000000000..69bba004de --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/utils/ProviderApiKeyDeserializer.java @@ -0,0 +1,16 @@ +package com.comet.opik.utils; + +import com.comet.opik.infrastructure.EncryptionUtils; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; + +import java.io.IOException; + +public class ProviderApiKeyDeserializer extends JsonDeserializer { + + @Override + public String deserialize(JsonParser parser, DeserializationContext ctx) throws IOException { + return EncryptionUtils.encrypt(parser.getText()); + } +} \ No newline at end of file diff --git a/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000006_add_provider_api_key_table.sql b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000006_add_provider_api_key_table.sql new file mode 100644 index 0000000000..97d7cc7b51 --- /dev/null +++ b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000006_add_provider_api_key_table.sql @@ -0,0 +1,17 @@ +--liquibase formatted sql +--changeset BorisTkachenko:000006_add_llm_provider_api_key_table + +CREATE TABLE IF NOT EXISTS llm_provider_api_key ( + id CHAR(36) NOT NULL, + provider VARCHAR(250) NOT NULL, + workspace_id VARCHAR(150) NOT NULL, + api_key VARCHAR(250) NOT NULL, + created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + created_by VARCHAR(100) NOT NULL DEFAULT 'admin', + last_updated_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + last_updated_by VARCHAR(100) NOT NULL DEFAULT 'admin', + CONSTRAINT `llm_provider_api_key_pk` PRIMARY KEY (id), + CONSTRAINT `llm_provider_api_key_workspace_id_provider` UNIQUE (workspace_id, provider) + ); + +--rollback DROP TABLE IF EXISTS llm_provider_api_key; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/LlmProviderApiKeyResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/LlmProviderApiKeyResourceTest.java new file mode 100644 index 0000000000..dd10524f36 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/LlmProviderApiKeyResourceTest.java @@ -0,0 +1,221 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.comet.opik.api.LlmProvider; +import com.comet.opik.api.ProviderApiKey; +import com.comet.opik.api.ProviderApiKeyUpdate; +import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; +import com.comet.opik.api.resources.utils.ClientSupportUtils; +import com.comet.opik.api.resources.utils.MigrationUtils; +import com.comet.opik.api.resources.utils.MySQLContainerUtils; +import com.comet.opik.api.resources.utils.RedisContainerUtils; +import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestUtils; +import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.domain.LlmProviderApiKeyDAO; +import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; +import com.comet.opik.infrastructure.EncryptionUtils; +import com.comet.opik.podam.PodamFactoryUtils; +import com.redis.testcontainers.RedisContainer; +import jakarta.ws.rs.HttpMethod; +import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import org.jdbi.v3.core.Jdbi; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.testcontainers.clickhouse.ClickHouseContainer; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.lifecycle.Startables; +import ru.vyarus.dropwizard.guice.test.ClientSupport; +import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; +import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; +import uk.co.jemos.podam.api.PodamFactory; + +import java.sql.SQLException; +import java.util.UUID; + +import static com.comet.opik.api.LlmProvider.OPEN_AI; +import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; +import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; +import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; +import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY; +import static org.assertj.core.api.Assertions.assertThat; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@DisplayName("Proxy Resource Test") +class LlmProviderApiKeyResourceTest { + public static final String URL_TEMPLATE = "%s/v1/private/llm-provider-key"; + + public static final String[] IGNORED_FIELDS = {"createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt"}; + + private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer(); + private static final ClickHouseContainer CLICKHOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer(); + private static final MySQLContainer MYSQL = MySQLContainerUtils.newMySQLContainer(); + + private static final String USER = UUID.randomUUID().toString(); + + @RegisterExtension + private static final TestDropwizardAppExtension app; + + private static final WireMockUtils.WireMockRuntime wireMock; + + static { + Startables.deepStart(REDIS, CLICKHOUSE_CONTAINER, MYSQL).join(); + + wireMock = WireMockUtils.startWireMock(); + + DatabaseAnalyticsFactory databaseAnalyticsFactory = ClickHouseContainerUtils + .newDatabaseAnalyticsFactory(CLICKHOUSE_CONTAINER, DATABASE_NAME); + + app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( + MYSQL.getJdbcUrl(), databaseAnalyticsFactory, wireMock.runtimeInfo(), REDIS.getRedisURI()); + } + + private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); + + private String baseURI; + private ClientSupport client; + private TransactionTemplate mySqlTemplate; + + @BeforeAll + void setUpAll(ClientSupport client, Jdbi jdbi, + TransactionTemplate mySqlTemplate) throws SQLException { + + MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters()); + + try (var connection = CLICKHOUSE_CONTAINER.createConnection("")) { + MigrationUtils.runDbMigration(connection, CLICKHOUSE_CHANGELOG_FILE, + ClickHouseContainerUtils.migrationParameters()); + } + + this.baseURI = "http://localhost:%d".formatted(client.getPort()); + this.client = client; + this.mySqlTemplate = mySqlTemplate; + + ClientSupportUtils.config(client); + } + + private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { + AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER); + } + + @AfterAll + void tearDownAll() { + wireMock.server().stop(); + } + + @Test + @DisplayName("Create and update provider Api Key") + void createAndUpdateProviderApiKey() { + + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + var provider = OPEN_AI; + String providerApiKey = factory.manufacturePojo(String.class); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var id = createProviderApiKey(provider, providerApiKey, apiKey, workspaceName, 201); + var expectedProviderApiKey = ProviderApiKey.builder().id(id).provider(provider).build(); + getAndAssertProviderApiKey(expectedProviderApiKey, apiKey, workspaceName); + checkEncryption(id, workspaceId, providerApiKey); + + String newProviderApiKey = factory.manufacturePojo(String.class); + updateProviderApiKey(id, newProviderApiKey, apiKey, workspaceName, 204); + checkEncryption(id, workspaceId, newProviderApiKey); + } + + @Test + @DisplayName("Create provider Api Key for existing provider should fail") + void createProviderApiKeyForExistingProviderShouldFail() { + + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + var provider = OPEN_AI; + String providerApiKey = factory.manufacturePojo(String.class); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + createProviderApiKey(provider, providerApiKey, apiKey, workspaceName, 201); + createProviderApiKey(provider, providerApiKey, apiKey, workspaceName, 409); + } + + @Test + @DisplayName("Update provider Api Key for non-existing Id") + void updateProviderFail() { + + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String providerApiKey = factory.manufacturePojo(String.class); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + // for non-existing id + updateProviderApiKey(UUID.randomUUID(), providerApiKey, apiKey, workspaceName, 404); + } + + private UUID createProviderApiKey(LlmProvider provider, String providerApiKey, String apiKey, String workspaceName, + int expectedStatus) { + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(ProviderApiKey.builder().provider(provider).apiKey(providerApiKey).build()))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(expectedStatus); + if (expectedStatus == 201) { + return TestUtils.getIdFromLocation(actualResponse.getLocation()); + } + + return null; + } + } + + private void updateProviderApiKey(UUID id, String providerApiKey, String apiKey, String workspaceName, + int expectedStatus) { + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path(id.toString()) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .method(HttpMethod.PATCH, Entity.json(ProviderApiKeyUpdate.builder().apiKey(providerApiKey).build()))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(expectedStatus); + } + } + + private void getAndAssertProviderApiKey(ProviderApiKey expected, String apiKey, String workspaceName) { + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path(expected.id().toString()) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + assertThat(actualResponse.hasEntity()).isTrue(); + + var actualEntity = actualResponse.readEntity(ProviderApiKey.class); + assertThat(actualEntity.provider()).isEqualTo(expected.provider()); + assertThat(actualEntity.apiKey()).isBlank(); + } + } + + private void checkEncryption(UUID id, String workspaceId, String expectedApiKey) { + String actualEncryptedApiKey = mySqlTemplate.inTransaction(READ_ONLY, handle -> { + var repository = handle.attach(LlmProviderApiKeyDAO.class); + return repository.findById(id, workspaceId).apiKey(); + }); + assertThat(EncryptionUtils.decrypt(actualEncryptedApiKey)).isEqualTo(expectedApiKey); + } +} \ No newline at end of file diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansResourceTest.java index b6406892a9..d8f1028076 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansResourceTest.java @@ -1146,7 +1146,8 @@ void getByProjectName__whenFilterTotalEstimatedCostEqual_NotEqual__thenReturnSpa .operator(operator) .value("0") .build()); - getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, apiKey); + getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, + apiKey); } static Stream getByProjectName__whenFilterByCorrespondingField__thenReturnSpansFiltered() { @@ -1185,7 +1186,8 @@ void getByProjectName__whenFilterNameEqual_NotEqual__thenReturnSpansFiltered(Ope .operator(operator) .value(spans.getFirst().name().toUpperCase()) .build()); - getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, apiKey); + getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, + apiKey); } private Stream equalAndNotEqualFilters() { @@ -1364,7 +1366,8 @@ void getByProjectName__whenFilterStartTimeEqual_NotEqual__thenReturnSpansFiltere .operator(operator) .value(spans.getFirst().startTime().toString()) .build()); - getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, apiKey); + getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, + apiKey); } @Test @@ -1650,7 +1653,8 @@ void getByProjectName__whenFilterMetadataEqualString__thenReturnSpansFiltered(Op .key("$.model[0].version") .value("OPENAI, CHAT-GPT 4.0") .build()); - getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, apiKey); + getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, + apiKey); } @Test @@ -2540,7 +2544,8 @@ void getByProjectName__whenFilterFeedbackScoresEqual_NotEqual__thenReturnSpansFi .key(spans.getFirst().feedbackScores().get(2).name().toUpperCase()) .value(spans.getFirst().feedbackScores().get(2).value().toString()) .build()); - getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, apiKey); + getAndAssertPage(workspaceName, projectName, filters, spans, expectedSpans.reversed(), unexpectedSpans, + apiKey); } private Stream getByProjectName__whenFilterFeedbackScoresEqual_NotEqual__thenReturnSpansFiltered() { diff --git a/apps/opik-backend/src/test/resources/config-test.yml b/apps/opik-backend/src/test/resources/config-test.yml index 5e8f1590c8..5f284ebe48 100644 --- a/apps/opik-backend/src/test/resources/config-test.yml +++ b/apps/opik-backend/src/test/resources/config-test.yml @@ -82,3 +82,6 @@ metadata: cors: enabled: ${CORS:-false} + +encryption: + key: ${OPIK_ENCRYPTION_KEY:-'GiTHubiLoVeYouAA'}