Skip to content

Commit

Permalink
[OPIK-616] Add fields to LLM Provider Api Key endpoints (#939)
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisTkachenko authored Dec 20, 2024
1 parent fe65e5a commit 7cae567
Show file tree
Hide file tree
Showing 12 changed files with 188 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;
import lombok.Builder;

import java.time.Instant;
Expand All @@ -22,8 +23,9 @@ public record ProviderApiKey(
@JsonView( {
View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID id,
@JsonView({View.Public.class, View.Write.class}) @NotNull LlmProvider provider,
@JsonView({
@JsonView({View.Public.class,
View.Write.class}) @NotBlank @JsonDeserialize(using = ProviderApiKeyDeserializer.class) String apiKey,
@JsonView({View.Public.class, View.Write.class}) @Size(max = 150) String name,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt,
Expand All @@ -49,6 +51,7 @@ public static class Public {
}
}

@Builder(toBuilder = true)
public record ProviderApiKeyPage(
@JsonView( {
Project.View.Public.class}) int page,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.Size;
import lombok.Builder;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record ProviderApiKeyUpdate(
@NotBlank @JsonDeserialize(using = ProviderApiKeyDeserializer.class) String apiKey) {
@NotBlank @JsonDeserialize(using = ProviderApiKeyDeserializer.class) String apiKey,
@Size(max = 150) String name) {

@Override
public String toString() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@

import java.util.UUID;

import static com.comet.opik.infrastructure.EncryptionUtils.decrypt;
import static com.comet.opik.infrastructure.EncryptionUtils.maskApiKey;

@Path("/v1/private/llm-provider-key")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
Expand All @@ -60,10 +63,21 @@ public Response find() {
String workspaceId = requestContext.get().getWorkspaceId();

log.info("Find LLM Provider's ApiKeys for workspaceId '{}'", workspaceId);
Page<ProviderApiKey> providerApiKeyPage = llmProviderApiKeyService.find(workspaceId);
ProviderApiKey.ProviderApiKeyPage providerApiKeyPage = llmProviderApiKeyService.find(workspaceId);
log.info("Found LLM Provider's ApiKeys for workspaceId '{}'", workspaceId);

return Response.ok().entity(providerApiKeyPage).build();
return Response.ok().entity(
providerApiKeyPage.toBuilder()
.content(
providerApiKeyPage.content().stream()
.map(providerApiKey -> providerApiKey.toBuilder()
.apiKey(maskApiKey(decrypt(providerApiKey.apiKey())))
.build())
.toList()
)
.build()
)
.build();
}

@GET
Expand All @@ -82,7 +96,9 @@ public Response getById(@PathParam("id") UUID id) {

log.info("Got LLM Provider's ApiKey by id '{}' on workspace_id '{}'", id, workspaceId);

return Response.ok().entity(providerApiKey).build();
return Response.ok().entity(providerApiKey.toBuilder()
.apiKey(maskApiKey(decrypt(providerApiKey.apiKey())))
.build()).build();
}

@POST
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.comet.opik.domain;

import com.comet.opik.api.ProviderApiKey;
import com.comet.opik.api.ProviderApiKeyUpdate;
import com.comet.opik.infrastructure.db.UUIDArgumentFactory;
import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory;
import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper;
Expand All @@ -19,18 +20,18 @@
@RegisterArgumentFactory(UUIDArgumentFactory.class)
public interface LlmProviderApiKeyDAO {

@SqlUpdate("INSERT INTO llm_provider_api_key (id, provider, workspace_id, api_key, created_by, last_updated_by) VALUES (:bean.id, :bean.provider, :workspaceId, :bean.apiKey, :bean.createdBy, :bean.lastUpdatedBy)")
@SqlUpdate("INSERT INTO llm_provider_api_key (id, provider, workspace_id, api_key, name, created_by, last_updated_by) " +
"VALUES (:bean.id, :bean.provider, :workspaceId, :bean.apiKey, :bean.name, :bean.createdBy, :bean.lastUpdatedBy)")
void save(@Bind("workspaceId") String workspaceId,
@BindMethods("bean") ProviderApiKey providerApiKey);

@SqlUpdate("UPDATE llm_provider_api_key SET " +
"api_key = :apiKey, " +
"last_updated_by = :lastUpdatedBy " +
"api_key = :bean.apiKey, name = :bean.name, last_updated_by = :lastUpdatedBy " +
"WHERE id = :id AND workspace_id = :workspaceId")
void update(@Bind("id") UUID id,
@Bind("workspaceId") String workspaceId,
@Bind("apiKey") String encryptedApiKey,
@Bind("lastUpdatedBy") String lastUpdatedBy);
@Bind("lastUpdatedBy") String lastUpdatedBy,
@BindMethods("bean") ProviderApiKeyUpdate providerApiKeyUpdate);

@SqlQuery("SELECT * FROM llm_provider_api_key WHERE id = :id AND workspace_id = :workspaceId")
ProviderApiKey findById(@Bind("id") UUID id, @Bind("workspaceId") String workspaceId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public interface LlmProviderApiKeyService {

ProviderApiKey find(UUID id, String workspaceId);

Page<ProviderApiKey> find(String workspaceId);
ProviderApiKey.ProviderApiKeyPage find(String workspaceId);

ProviderApiKey saveApiKey(ProviderApiKey providerApiKey, String userName, String workspaceId);

Expand Down Expand Up @@ -58,12 +58,11 @@ public ProviderApiKey find(@NonNull UUID id, @NonNull String workspaceId) {
return repository.fetch(id, workspaceId).orElseThrow(this::createNotFoundError);
});

return providerApiKey.toBuilder()
.build();
return providerApiKey;
}

@Override
public Page<ProviderApiKey> find(@NonNull String workspaceId) {
public ProviderApiKey.ProviderApiKeyPage find(@NonNull String workspaceId) {
List<ProviderApiKey> providerApiKeys = template.inTransaction(READ_ONLY, handle -> {
var repository = handle.attach(LlmProviderApiKeyDAO.class);
return repository.find(workspaceId);
Expand Down Expand Up @@ -122,8 +121,8 @@ public void updateApiKey(@NonNull UUID id, @NonNull ProviderApiKeyUpdate provide

repository.update(providerApiKey.id(),
workspaceId,
providerApiKeyUpdate.apiKey(),
userName);
userName,
providerApiKeyUpdate);

return null;
});
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.comet.opik.infrastructure;

import lombok.NonNull;
import org.apache.commons.lang3.StringUtils;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
Expand Down Expand Up @@ -50,4 +51,11 @@ public static String decrypt(@NonNull String encryptedData) {
throw new SecurityException("Failed to decrypt. " + ex.getMessage(), ex);
}
}

public static String maskApiKey(@NonNull String apiKey) {
return apiKey.length() <= 12
? StringUtils.repeat('*', apiKey.length())
: apiKey.substring(0, 3) + StringUtils.repeat('*', apiKey.length() - 6)
+ apiKey.substring(apiKey.length() - 3);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
--liquibase formatted sql
--changeset BorisTkachenko:000008_add_name_to_provider_api_key

ALTER TABLE llm_provider_api_key ADD COLUMN name VARCHAR(150) DEFAULT NULL;

--rollback ALTER TABLE llm_provider_api_key DROP COLUMN name;
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import ru.vyarus.dropwizard.guice.test.ClientSupport;
import uk.co.jemos.podam.api.PodamUtils;

import java.util.Set;
import java.util.UUID;
Expand All @@ -31,23 +30,17 @@ public LlmProviderApiKeyResourceClient(ClientSupport client) {
}

public ProviderApiKey createProviderApiKey(
String providerApiKey, String apiKey, String workspaceName, int expectedStatus) {
return createProviderApiKey(providerApiKey, randomLlmProvider(), apiKey, workspaceName, expectedStatus);
}

public ProviderApiKey createProviderApiKey(
String providerApiKey, LlmProvider llmProvider, String apiKey, String workspaceName, int expectedStatus) {
ProviderApiKey body = ProviderApiKey.builder().provider(llmProvider).apiKey(providerApiKey).build();
ProviderApiKey providerApiKey, String apiKey, String workspaceName, int expectedStatus) {
try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI))
.request()
.accept(MediaType.APPLICATION_JSON_TYPE)
.header(HttpHeaders.AUTHORIZATION, apiKey)
.header(WORKSPACE_HEADER, workspaceName)
.post(Entity.json(body))) {
.post(Entity.json(providerApiKey))) {

assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(expectedStatus);
if (expectedStatus == 201) {
return body.toBuilder()
return providerApiKey.toBuilder()
.id(TestUtils.getIdFromLocation(actualResponse.getLocation()))
.build();
}
Expand All @@ -56,15 +49,23 @@ public ProviderApiKey createProviderApiKey(
}
}

public void updateProviderApiKey(UUID id, String providerApiKey, String apiKey, String workspaceName,
public ProviderApiKey createProviderApiKey(
String providerApiKey, LlmProvider llmProvider, String apiKey, String workspaceName, int expectedStatus) {
ProviderApiKey body = ProviderApiKey.builder().provider(llmProvider).apiKey(providerApiKey).build();

return createProviderApiKey(body, apiKey, workspaceName, expectedStatus);
}

public void updateProviderApiKey(UUID id, ProviderApiKeyUpdate providerApiKeyUpdate, String apiKey,
String workspaceName,
int expectedStatus) {
try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI))
.path(id.toString())
.request()
.accept(MediaType.APPLICATION_JSON_TYPE)
.header(HttpHeaders.AUTHORIZATION, apiKey)
.header(WORKSPACE_HEADER, workspaceName)
.method(HttpMethod.PATCH, Entity.json(ProviderApiKeyUpdate.builder().apiKey(providerApiKey).build()))) {
.method(HttpMethod.PATCH, Entity.json(providerApiKeyUpdate))) {

assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(expectedStatus);
}
Expand Down Expand Up @@ -95,10 +96,7 @@ public ProviderApiKey getById(UUID id, String workspaceName, String apiKey, int
if (expectedStatus == 200) {
assertThat(actualResponse.hasEntity()).isTrue();

var actualEntity = actualResponse.readEntity(ProviderApiKey.class);
assertThat(actualEntity.apiKey()).isBlank();

return actualEntity;
return actualResponse.readEntity(ProviderApiKey.class);
}

return null;
Expand All @@ -116,13 +114,8 @@ public Page<ProviderApiKey> getAll(String workspaceName, String apiKey) {
assertThat(actualResponse.hasEntity()).isTrue();

var actualEntity = actualResponse.readEntity(ProviderApiKey.ProviderApiKeyPage.class);
actualEntity.content().forEach(providerApiKey -> assertThat(providerApiKey.apiKey()).isBlank());

return actualEntity;
}
}

public LlmProvider randomLlmProvider() {
return LlmProvider.values()[PodamUtils.getIntegerInRange(0, LlmProvider.values().length - 1)];
}
}
Loading

0 comments on commit 7cae567

Please sign in to comment.