diff --git a/apps/opik-backend/config.yml b/apps/opik-backend/config.yml index 61d2fb8ab2..0bbd465603 100644 --- a/apps/opik-backend/config.yml +++ b/apps/opik-backend/config.yml @@ -88,3 +88,6 @@ metadata: cors: enabled: ${CORS:-false} + +encryption: + key: ${OPIK_ENCRYPTION_KEY:-'GiTHubiLoVeYouAA'} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKey.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKey.java new file mode 100644 index 0000000000..cc3b653491 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKey.java @@ -0,0 +1,34 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotBlank; +import lombok.Builder; + +import java.time.Instant; +import java.util.UUID; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record ProviderApiKey( + @JsonView( { + View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID id, + @JsonView({View.Public.class, View.Write.class}) @NotBlank String provider, + @JsonView({View.Write.class}) @NotBlank String apiKey, + @JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, + @JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, + @JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy +) { + public static class View { + public static class Write { + } + + public static class Public { + } + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKeyUpdate.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKeyUpdate.java new file mode 100644 index 0000000000..ca2950b9d7 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKeyUpdate.java @@ -0,0 +1,16 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.NotBlank; +import lombok.Builder; +import lombok.Getter; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +@Getter +public class ProviderApiKeyUpdate { + @NotBlank String apiKey; +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProxyResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProxyResource.java new file mode 100644 index 0000000000..77c095551d --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProxyResource.java @@ -0,0 +1,106 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.codahale.metrics.annotation.Timed; +import com.comet.opik.api.ProviderApiKey; +import com.comet.opik.api.ProviderApiKeyUpdate; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.domain.ProxyService; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.fasterxml.jackson.annotation.JsonView; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.headers.Header; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.parameters.RequestBody; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.validation.Valid; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.GET; +import jakarta.ws.rs.PATCH; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +import java.util.UUID; + +@Path("/v1/private/proxy") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +@Timed +@Slf4j +@RequiredArgsConstructor(onConstructor_ = @Inject) +@Tag(name = "Proxy", description = "LLM Provider Proxy") +public class ProxyResource { + + private final @NonNull ProxyService proxyService; + private final @NonNull Provider requestContext; + + @GET + @Path("/api_key/{id}") + @Operation(operationId = "getProviderApiKeyById", summary = "Get Provider's ApiKey by id", description = "Get Provider's ApiKey by id", responses = { + @ApiResponse(responseCode = "200", description = "ProviderApiKey resource", content = @Content(schema = @Schema(implementation = ProviderApiKey.class)))}) + @JsonView({ProviderApiKey.View.Public.class}) + public Response getById(@PathParam("id") UUID id) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting Provider's ApiKey by id '{}' on workspace_id '{}'", id, workspaceId); + + ProviderApiKey providerApiKey = proxyService.get(id); + + log.info("Got Provider's ApiKey by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.ok().entity(providerApiKey).build(); + } + + @POST + @Path("/api_key") + @Operation(operationId = "storeApiKey", summary = "Store Provider's ApiKey", description = "Store Provider's ApiKey", responses = { + @ApiResponse(responseCode = "201", description = "Created", headers = { + @Header(name = "Location", required = true, example = "${basePath}/v1/private/proxy/api_key/{apiKeyId}", schema = @Schema(implementation = String.class))}), + @ApiResponse(responseCode = "401", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "403", description = "Access forbidden", content = @Content(schema = @Schema(implementation = ErrorMessage.class))) + }) + public Response saveApiKey( + @RequestBody(content = @Content(schema = @Schema(implementation = ProviderApiKey.class))) @JsonView(ProviderApiKey.View.Write.class) @Valid ProviderApiKey providerApiKey, + @Context UriInfo uriInfo) { + String workspaceId = requestContext.get().getWorkspaceId(); + log.info("Save api key for provider '{}', on workspace_id '{}'", providerApiKey.provider(), workspaceId); + var providerApiKeyId = proxyService.saveApiKey(providerApiKey).id(); + log.info("Saved api key for provider '{}', on workspace_id '{}'", providerApiKey.provider(), workspaceId); + + var uri = uriInfo.getAbsolutePathBuilder().path("/%s".formatted(providerApiKeyId)).build(); + + return Response.created(uri).build(); + } + + @PATCH + @Path("/api_key/{id}") + @Operation(operationId = "storeApiKey", summary = "Store Provider's ApiKey", description = "Store Provider's ApiKey", responses = { + @ApiResponse(responseCode = "204", description = "No Content"), + @ApiResponse(responseCode = "401", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "403", description = "Access forbidden", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "404", description = "Not found", content = @Content(schema = @Schema(implementation = ErrorMessage.class))) + }) + public Response updateApiKey(@PathParam("id") UUID id, + @RequestBody(content = @Content(schema = @Schema(implementation = ProviderApiKeyUpdate.class))) @Valid ProviderApiKeyUpdate providerApiKeyUpdate) { + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Updating api key for provider with id '{}' on workspaceId '{}'", id, workspaceId); + proxyService.updateApiKey(id, providerApiKeyUpdate); + log.info("Updated api key for provider with id '{}' on workspaceId '{}'", id, workspaceId); + + return Response.noContent().build(); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProviderApiKeyDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProviderApiKeyDAO.java new file mode 100644 index 0000000000..3624711c75 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProviderApiKeyDAO.java @@ -0,0 +1,38 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.ProviderApiKey; +import com.comet.opik.infrastructure.db.UUIDArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.customizer.BindMethods; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; + +import java.util.Optional; +import java.util.UUID; + +@RegisterConstructorMapper(ProviderApiKey.class) +@RegisterArgumentFactory(UUIDArgumentFactory.class) +public interface ProviderApiKeyDAO { + + @SqlUpdate("INSERT INTO provider_api_key (id, provider, workspace_id, api_key, created_by, last_updated_by) VALUES (:bean.id, :bean.provider, :workspaceId, :bean.apiKey, :bean.createdBy, :bean.lastUpdatedBy)") + void save(@Bind("workspaceId") String workspaceId, + @BindMethods("bean") ProviderApiKey providerApiKey); + + @SqlUpdate("UPDATE provider_api_key SET " + + "api_key = :apiKey, " + + "last_updated_by = :lastUpdatedBy " + + "WHERE id = :id AND workspace_id = :workspaceId") + void update(@Bind("id") UUID id, + @Bind("workspaceId") String workspaceId, + @Bind("apiKey") String encryptedApiKey, + @Bind("lastUpdatedBy") String lastUpdatedBy); + + @SqlQuery("SELECT * FROM provider_api_key WHERE id = :id AND workspace_id = :workspaceId") + ProviderApiKey findById(@Bind("id") UUID id, @Bind("workspaceId") String workspaceId); + + default Optional fetch(UUID id, String workspaceId) { + return Optional.ofNullable(findById(id, workspaceId)); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProxyService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProxyService.java new file mode 100644 index 0000000000..97c7412ea9 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProxyService.java @@ -0,0 +1,131 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.ProviderApiKey; +import com.comet.opik.api.ProviderApiKeyUpdate; +import com.comet.opik.api.error.EntityAlreadyExistsException; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.infrastructure.EncryptionService; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.google.inject.ImplementedBy; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.inject.Singleton; +import jakarta.ws.rs.NotFoundException; +import jakarta.ws.rs.core.Response; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.jdbi.v3.core.statement.UnableToExecuteStatementException; +import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; + +import java.sql.SQLIntegrityConstraintViolationException; +import java.util.List; +import java.util.UUID; + +import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY; +import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; + +@ImplementedBy(ProxyServiceImpl.class) +public interface ProxyService { + + ProviderApiKey get(UUID id); + ProviderApiKey saveApiKey(ProviderApiKey providerApiKey); + void updateApiKey(UUID id, ProviderApiKeyUpdate providerApiKeyUpdate); +} + +@Slf4j +@Singleton +@RequiredArgsConstructor(onConstructor_ = @Inject) +class ProxyServiceImpl implements ProxyService { + + private static final String PROVIDER_API_KEY_ALREADY_EXISTS = "Api key for this provider already exists"; + private final @NonNull Provider requestContext; + private final @NonNull IdGenerator idGenerator; + private final @NonNull TransactionTemplate template; + private final @NonNull EncryptionService encryptionService; + + @Override + public ProviderApiKey get(UUID id) { + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting provider api key with id '{}', workspaceId '{}'", id, workspaceId); + + var providerApiKey = template.inTransaction(READ_ONLY, handle -> { + + var repository = handle.attach(ProviderApiKeyDAO.class); + + return repository.fetch(id, workspaceId).orElseThrow(this::createNotFoundError); + }); + log.info("Got provider api key with id '{}', workspaceId '{}'", id, workspaceId); + + return providerApiKey.toBuilder() + .apiKey(encryptionService.decrypt(providerApiKey.apiKey())) + .build(); + } + + @Override + public ProviderApiKey saveApiKey(@NonNull ProviderApiKey providerApiKey) { + UUID apiKeyId = idGenerator.generateId(); + String userName = requestContext.get().getUserName(); + String workspaceId = requestContext.get().getWorkspaceId(); + + var newProviderApiKey = providerApiKey.toBuilder() + .id(apiKeyId) + .apiKey(encryptionService.encrypt(providerApiKey.apiKey())) + .createdBy(userName) + .lastUpdatedBy(userName) + .build(); + + try { + template.inTransaction(WRITE, handle -> { + + var repository = handle.attach(ProviderApiKeyDAO.class); + repository.save(workspaceId, newProviderApiKey); + + return newProviderApiKey; + }); + + return get(apiKeyId); + } catch (UnableToExecuteStatementException e) { + if (e.getCause() instanceof SQLIntegrityConstraintViolationException) { + throw newConflict(); + } else { + throw e; + } + } + } + + @Override + public void updateApiKey(@NonNull UUID id, @NonNull ProviderApiKeyUpdate providerApiKeyUpdate) { + String userName = requestContext.get().getUserName(); + String workspaceId = requestContext.get().getWorkspaceId(); + String encryptedApiKey = encryptionService.encrypt(providerApiKeyUpdate.getApiKey()); + + template.inTransaction(WRITE, handle -> { + + var repository = handle.attach(ProviderApiKeyDAO.class); + + ProviderApiKey providerApiKey = repository.fetch(id, workspaceId) + .orElseThrow(this::createNotFoundError); + + repository.update(providerApiKey.id(), + workspaceId, + encryptedApiKey, + userName); + + return null; + }); + } + + private EntityAlreadyExistsException newConflict() { + log.info(PROVIDER_API_KEY_ALREADY_EXISTS); + return new EntityAlreadyExistsException(new ErrorMessage(List.of(PROVIDER_API_KEY_ALREADY_EXISTS))); + } + + private NotFoundException createNotFoundError() { + String message = "Provider api key not found"; + log.info(message); + return new NotFoundException(message, + Response.status(Response.Status.NOT_FOUND).entity(new ErrorMessage(List.of(message))).build()); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionConfig.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionConfig.java new file mode 100644 index 0000000000..c593b8ec53 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionConfig.java @@ -0,0 +1,15 @@ +package com.comet.opik.infrastructure; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotNull; +import lombok.Data; + +@Data +public class EncryptionConfig { + + @Valid + @JsonProperty + @NotNull + private String key; +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionService.java new file mode 100644 index 0000000000..cca8147b9c --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/EncryptionService.java @@ -0,0 +1,59 @@ +package com.comet.opik.infrastructure; + +import com.google.inject.ImplementedBy; +import jakarta.inject.Inject; +import jakarta.inject.Singleton; +import lombok.NonNull; + +import javax.crypto.Cipher; +import javax.crypto.spec.SecretKeySpec; +import java.nio.charset.StandardCharsets; +import java.security.Key; +import java.util.Base64; + +@ImplementedBy(EncryptionServiceImpl.class) +public interface EncryptionService { + String encrypt(String data); + String decrypt(String encryptedData); +} + +@Singleton +class EncryptionServiceImpl implements EncryptionService { + + private static final String ALGO = "AES"; + private final Base64.Encoder mimeEncoder = Base64.getMimeEncoder(); + private final Base64.Decoder mimeDecoder = Base64.getMimeDecoder(); + @NonNull + Key key; + + @Inject + public EncryptionServiceImpl(@NonNull OpikConfiguration config) { + byte[] keyBytes = config.getEncryption().getKey().getBytes(StandardCharsets.UTF_8); + key = new SecretKeySpec(keyBytes, ALGO); + } + + @Override + public String encrypt(String data) { + try { + Cipher c = Cipher.getInstance(ALGO); + c.init(Cipher.ENCRYPT_MODE, key); + byte[] encVal = c.doFinal(data.getBytes()); + return mimeEncoder.encodeToString(encVal); + } catch (Exception ex) { + throw new RuntimeException("Failed to encrypt. " + ex.getMessage(), ex); + } + } + + @Override + public String decrypt(String encryptedData) { + try { + Cipher c = Cipher.getInstance(ALGO); + c.init(Cipher.DECRYPT_MODE, key); + byte[] decordedValue = mimeDecoder.decode(encryptedData); + byte[] decValue = c.doFinal(decordedValue); + return new String(decValue); + } catch (Exception ex) { + throw new RuntimeException("Failed to decrypt. " + ex.getMessage(), ex); + } + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java index 0c3452fb3c..3cb785edeb 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/OpikConfiguration.java @@ -53,4 +53,8 @@ public class OpikConfiguration extends JobConfiguration { @Valid @NotNull @JsonProperty private BatchOperationsConfig batchOperations = new BatchOperationsConfig(); + + @Valid + @NotNull @JsonProperty + private EncryptionConfig encryption = new EncryptionConfig(); } diff --git a/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000006_add_provider_api_key_table.sql b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000006_add_provider_api_key_table.sql new file mode 100644 index 0000000000..8e1c4ce69c --- /dev/null +++ b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000006_add_provider_api_key_table.sql @@ -0,0 +1,17 @@ +--liquibase formatted sql +--changeset BorisTkachenko:000006_add_provider_api_key_table + +CREATE TABLE IF NOT EXISTS provider_api_key ( + id CHAR(36) NOT NULL, + provider VARCHAR(250) NOT NULL, + workspace_id VARCHAR(150) NOT NULL, + api_key VARCHAR(250) NOT NULL, + created_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + created_by VARCHAR(100) NOT NULL DEFAULT 'admin', + last_updated_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + last_updated_by VARCHAR(100) NOT NULL DEFAULT 'admin', + CONSTRAINT `provider_api_key_pk` PRIMARY KEY (id), + CONSTRAINT `provider_api_key_workspace_id_provider` UNIQUE (workspace_id, provider) + ); + +--rollback DROP TABLE IF EXISTS provider_api_key; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProxyResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProxyResourceTest.java new file mode 100644 index 0000000000..454267fde7 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProxyResourceTest.java @@ -0,0 +1,219 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.comet.opik.api.ProviderApiKey; +import com.comet.opik.api.ProviderApiKeyUpdate; +import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; +import com.comet.opik.api.resources.utils.ClientSupportUtils; +import com.comet.opik.api.resources.utils.MigrationUtils; +import com.comet.opik.api.resources.utils.MySQLContainerUtils; +import com.comet.opik.api.resources.utils.RedisContainerUtils; +import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestUtils; +import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.domain.ProviderApiKeyDAO; +import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; +import com.comet.opik.infrastructure.EncryptionService; +import com.comet.opik.podam.PodamFactoryUtils; +import com.redis.testcontainers.RedisContainer; +import jakarta.ws.rs.HttpMethod; +import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import org.jdbi.v3.core.Jdbi; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.testcontainers.clickhouse.ClickHouseContainer; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.lifecycle.Startables; +import ru.vyarus.dropwizard.guice.test.ClientSupport; +import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; +import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; +import uk.co.jemos.podam.api.PodamFactory; + +import java.sql.SQLException; +import java.util.UUID; + +import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; +import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; +import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; +import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY; +import static org.assertj.core.api.Assertions.assertThat; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@DisplayName("Proxy Resource Test") +class ProxyResourceTest { + public static final String URL_TEMPLATE = "%s/v1/private/proxy"; + + public static final String[] IGNORED_FIELDS = {"createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt"}; + + private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer(); + private static final ClickHouseContainer CLICKHOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer(); + private static final MySQLContainer MYSQL = MySQLContainerUtils.newMySQLContainer(); + + private static final String USER = UUID.randomUUID().toString(); + + @RegisterExtension + private static final TestDropwizardAppExtension app; + + private static final WireMockUtils.WireMockRuntime wireMock; + + static { + Startables.deepStart(REDIS, CLICKHOUSE_CONTAINER, MYSQL).join(); + + wireMock = WireMockUtils.startWireMock(); + + DatabaseAnalyticsFactory databaseAnalyticsFactory = ClickHouseContainerUtils + .newDatabaseAnalyticsFactory(CLICKHOUSE_CONTAINER, DATABASE_NAME); + + app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( + MYSQL.getJdbcUrl(), databaseAnalyticsFactory, wireMock.runtimeInfo(), REDIS.getRedisURI()); + } + + private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); + + private String baseURI; + private ClientSupport client; + private EncryptionService encryptionService; + private TransactionTemplate mySqlTemplate; + + @BeforeAll + void setUpAll(ClientSupport client, Jdbi jdbi, EncryptionService encryptionService, TransactionTemplate mySqlTemplate) throws SQLException { + + MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters()); + + try (var connection = CLICKHOUSE_CONTAINER.createConnection("")) { + MigrationUtils.runDbMigration(connection, CLICKHOUSE_CHANGELOG_FILE, + ClickHouseContainerUtils.migrationParameters()); + } + + this.baseURI = "http://localhost:%d".formatted(client.getPort()); + this.client = client; + this.encryptionService = encryptionService; + this.mySqlTemplate = mySqlTemplate; + + ClientSupportUtils.config(client); + } + + private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { + AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER); + } + + @AfterAll + void tearDownAll() { + wireMock.server().stop(); + } + + @Test + @DisplayName("Create and update provider Api Key") + void createAndUpdateProviderApiKey() { + + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String provider = factory.manufacturePojo(String.class); + String providerApiKey = factory.manufacturePojo(String.class); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var id = createProviderApiKey(provider, providerApiKey, apiKey, workspaceName, 201); + var expectedProviderApiKey = ProviderApiKey.builder().id(id).provider(provider).build(); + getAndAssertProviderApiKey(expectedProviderApiKey, apiKey, workspaceName); + checkEncryption(id, workspaceId, providerApiKey); + + String newProviderApiKey = factory.manufacturePojo(String.class); + updateProviderApiKey(id, newProviderApiKey, apiKey, workspaceName, 204); + checkEncryption(id, workspaceId, newProviderApiKey); + } + + @Test + @DisplayName("Create provider Api Key for existing provider should fail") + void createProviderApiKeyForExistingProviderShouldFail() { + + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String provider = factory.manufacturePojo(String.class); + String providerApiKey = factory.manufacturePojo(String.class); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + createProviderApiKey(provider, providerApiKey, apiKey, workspaceName, 201); + createProviderApiKey(provider, providerApiKey, apiKey, workspaceName, 409); + } + + @Test + @DisplayName("Update provider Api Key for non-existing Id") + void updateProviderFail() { + + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + String providerApiKey = factory.manufacturePojo(String.class); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + // for non-existing id + updateProviderApiKey(UUID.randomUUID(), providerApiKey, apiKey, workspaceName, 404); + } + + private UUID createProviderApiKey(String provider, String providerApiKey, String apiKey, String workspaceName, int expectedStatus) { + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path("api_key") + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(ProviderApiKey.builder().provider(provider).apiKey(providerApiKey).build()))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(expectedStatus); + if (expectedStatus == 201) { + return TestUtils.getIdFromLocation(actualResponse.getLocation()); + } + + return null; + } + } + + private void updateProviderApiKey(UUID id, String providerApiKey, String apiKey, String workspaceName, int expectedStatus) { + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path("api_key/" + id.toString()) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .method(HttpMethod.PATCH, Entity.json(ProviderApiKeyUpdate.builder().apiKey(providerApiKey).build()))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(expectedStatus); + } + } + + private void getAndAssertProviderApiKey(ProviderApiKey expected, String apiKey, String workspaceName) { + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path("api_key/" + expected.id().toString()) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + assertThat(actualResponse.hasEntity()).isTrue(); + + var actualEntity = actualResponse.readEntity(ProviderApiKey.class); + assertThat(actualEntity.provider()).isEqualTo(expected.provider()); + assertThat(actualEntity.apiKey()).isBlank(); + } + } + + private void checkEncryption(UUID id, String workspaceId, String expectedApiKey) { + String actualEncryptedApiKey = mySqlTemplate.inTransaction(READ_ONLY, handle -> { + var repository = handle.attach(ProviderApiKeyDAO.class); + return repository.findById(id, workspaceId).apiKey(); + }); + assertThat(encryptionService.decrypt(actualEncryptedApiKey)).isEqualTo(expectedApiKey); + } +} \ No newline at end of file diff --git a/apps/opik-backend/src/test/resources/config-test.yml b/apps/opik-backend/src/test/resources/config-test.yml index 5e8f1590c8..5f284ebe48 100644 --- a/apps/opik-backend/src/test/resources/config-test.yml +++ b/apps/opik-backend/src/test/resources/config-test.yml @@ -82,3 +82,6 @@ metadata: cors: enabled: ${CORS:-false} + +encryption: + key: ${OPIK_ENCRYPTION_KEY:-'GiTHubiLoVeYouAA'}