Skip to content

Commit

Permalink
OPIK-626 Add metadata and change_description to prompt_version (#928)
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisTkachenko authored Dec 19, 2024
1 parent 403dfdb commit f83916d
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
Expand Down Expand Up @@ -30,6 +31,8 @@ public record Prompt(
Prompt.View.Updatable.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description,
@JsonView({
Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template,
@JsonView({Prompt.View.Write.class}) @Nullable JsonNode metadata,
@JsonView({Prompt.View.Write.class}) @Nullable String changeDescription,
@JsonView({Prompt.View.Public.class,
Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({Prompt.View.Public.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import com.comet.opik.utils.ValidationUtils;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.annotation.Nullable;
import jakarta.validation.constraints.NotBlank;
import lombok.Builder;
import org.jdbi.v3.json.Json;

import java.time.Instant;
import java.util.List;
Expand All @@ -31,6 +33,10 @@ public record PromptVersion(
PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent. it must be 8 characters long", requiredMode = Schema.RequiredMode.NOT_REQUIRED, pattern = ValidationUtils.COMMIT_PATTERN) @CommitValidation String commit,
@JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) @NotBlank String template,
@Json @JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) @Nullable JsonNode metadata,
@JsonView({PromptVersion.View.Public.class, Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) @Nullable String changeDescription,
@JsonView({Prompt.View.Detail.class,
PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Set<String> variables,
@JsonView({Prompt.View.Detail.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ SELECT JSON_OBJECT(
'prompt_id', pv.prompt_id,
'commit', pv.commit,
'template', pv.template,
'metadata', pv.metadata,
'change_description', pv.change_description,
'created_at', pv.created_at,
'created_by', pv.created_by,
'last_updated_at', pv.last_updated_at,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,16 @@ public Prompt create(@NonNull Prompt prompt) {

if (!StringUtils.isEmpty(prompt.template())) {
EntityConstraintHandler
.handle(() -> createPromptVersionFromPromptRequest(createdPrompt, workspaceId, prompt.template()))
.handle(() -> createPromptVersionFromPromptRequest(createdPrompt, workspaceId, prompt))
.withRetry(3, this::newVersionConflict);
}

return createdPrompt;
}

private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt, String workspaceId,
String template) {
private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt,
String workspaceId,
Prompt promptPayload) {
log.info("Creating prompt version for prompt id '{}'", createdPrompt.id());

var createdVersion = transactionTemplate.inTransaction(WRITE, handle -> {
Expand All @@ -115,7 +116,9 @@ private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt,
.id(versionId)
.promptId(createdPrompt.id())
.commit(CommitUtils.getCommit(versionId))
.template(template)
.template(promptPayload.template())
.metadata(promptPayload.metadata())
.changeDescription(promptPayload.changeDescription())
.createdBy(createdPrompt.createdBy())
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
@RegisterConstructorMapper(PromptVersionId.class)
interface PromptVersionDAO {

@SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, created_by, workspace_id) " +
"VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.createdBy, :workspace_id)")
@SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, metadata, change_description, created_by, workspace_id) " +
"VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.metadata, :bean.changeDescription, :bean.createdBy, :workspace_id)")
void save(@Bind("workspace_id") String workspaceId, @BindMethods("bean") PromptVersion prompt);

@SqlQuery("SELECT * FROM prompt_versions WHERE id = :id AND workspace_id = :workspace_id")
Expand All @@ -29,7 +29,7 @@ interface PromptVersionDAO {
@SqlQuery("SELECT count(id) FROM prompt_versions WHERE prompt_id = :prompt_id AND workspace_id = :workspace_id")
long countByPromptId(@Bind("prompt_id") UUID promptId, @Bind("workspace_id") String workspaceId);

@SqlQuery("SELECT id, prompt_id, commit, template, created_at, created_by FROM prompt_versions WHERE prompt_id = :prompt_id AND workspace_id = :workspace_id ORDER BY id DESC LIMIT :limit OFFSET :offset")
@SqlQuery("SELECT * FROM prompt_versions WHERE prompt_id = :prompt_id AND workspace_id = :workspace_id ORDER BY id DESC LIMIT :limit OFFSET :offset")
List<PromptVersion> findByPromptId(@Bind("prompt_id") UUID promptId, @Bind("workspace_id") String workspaceId,
@Bind("limit") int limit, @Bind("offset") int offset);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ private PromptVersion mapObject(JsonNode jsonNode) {
.promptId(UUID.fromString(jsonNode.get("prompt_id").asText()))
.commit(jsonNode.get("commit").asText())
.template(jsonNode.get("template").asText())
.metadata(jsonNode.get("metadata"))
.changeDescription(jsonNode.get("change_description").asText())
.variables(MustacheVariableExtractor.extractVariables(jsonNode.get("template").asText()))
.createdAt(Instant.from(FORMATTER.parse(jsonNode.get("created_at").asText())))
.createdBy(jsonNode.get("created_by").asText())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
--liquibase formatted sql
--changeset BorisTkachenko:000007_add_change_description_metadata_to_prompt_version

ALTER TABLE prompt_versions ADD COLUMN change_description MEDIUMTEXT DEFAULT NULL;
ALTER TABLE prompt_versions ADD COLUMN metadata JSON DEFAULT NULL;

--rollback ALTER TABLE prompt_versions DROP COLUMN change_description;
--rollback ALTER TABLE prompt_versions DROP COLUMN metadata;
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class PromptResourceTest {
private static final TestDropwizardAppExtension app;

private static final WireMockUtils.WireMockRuntime wireMock;
private static final String[] IGNORED_FIELDS = {"latestVersion", "template"};
private static final String[] IGNORED_FIELDS = {"latestVersion", "template", "metadata", "changeDescription"};

static {
Startables.deepStart(REDIS, CLICKHOUSE_CONTAINER, MYSQL).join();
Expand Down Expand Up @@ -429,7 +429,7 @@ void getPromptVersionsById__whenApiKeyIsPresent__thenReturnProperResponse(String
promptVersion = createPromptVersion(request, okApikey, workspaceName);

try (var actualResponse = client
.target(RESOURCE_PATH.formatted(baseURI) + "/%s/versions".formatted(promptVersion.id()))
.target(RESOURCE_PATH.formatted(baseURI) + "/%s/versions".formatted(promptVersion.promptId()))
.request()
.accept(MediaType.APPLICATION_JSON_TYPE)
.header(HttpHeaders.AUTHORIZATION, apiKey)
Expand Down Expand Up @@ -1391,6 +1391,8 @@ void when__promptHasMultipleVersions__thenReturnPromptWithLatestVersion() {

Prompt expectedPrompt = prompt.toBuilder()
.template(promptVersion.template())
.metadata(promptVersion.metadata())
.changeDescription(promptVersion.changeDescription())
.versionCount(2L)
.build();

Expand Down Expand Up @@ -1452,6 +1454,8 @@ private void assertLatestVersion(Prompt actualPrompt, Prompt expectedPrompt, Set
assertThat(promptVersion.commit())
.isEqualTo(promptVersion.id().toString().substring(promptVersion.id().toString().length() - 8));
assertThat(promptVersion.template()).isEqualTo(expectedPrompt.template());
assertThat(promptVersion.metadata()).isEqualTo(expectedPrompt.metadata());
assertThat(promptVersion.changeDescription()).isEqualTo(expectedPrompt.changeDescription());
assertThat(promptVersion.variables()).isEqualTo(expectedVariables);
assertThat(promptVersion.createdBy()).isEqualTo(USER);
assertThat(promptVersion.createdAt()).isBetween(expectedPrompt.createdAt(), Instant.now());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.comet.opik.podam.manufacturer;

import com.comet.opik.api.PromptVersion;
import com.fasterxml.jackson.databind.JsonNode;
import org.apache.commons.lang3.RandomStringUtils;
import uk.co.jemos.podam.api.AttributeMetadata;
import uk.co.jemos.podam.api.DataProviderStrategy;
Expand Down Expand Up @@ -40,6 +41,8 @@ public PromptVersion getType(DataProviderStrategy strategy, AttributeMetadata me
.id(id)
.commit(id.toString().substring(id.toString().length() - 8))
.template(template)
.metadata(strategy.getTypeValue(metadata, context, JsonNode.class))
.changeDescription(strategy.getTypeValue(metadata, context, String.class))
.variables(Set.of(variable1, variable2, variable3))
.promptId(strategy.getTypeValue(metadata, context, UUID.class))
.createdBy(strategy.getTypeValue(metadata, context, String.class))
Expand Down

0 comments on commit f83916d

Please sign in to comment.