Skip to content

Commit

Permalink
OPIK-71: Add metadata to Experiment (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrescrz authored Sep 11, 2024
1 parent 77996fe commit 5c338ed
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 40 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
Expand All @@ -21,6 +22,7 @@ public record Experiment(
@JsonView({Experiment.View.Public.class, Experiment.View.Write.class}) @NotBlank String datasetName,
@JsonView({Experiment.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID datasetId,
@JsonView({Experiment.View.Public.class, Experiment.View.Write.class}) @NotBlank String name,
@JsonView({Experiment.View.Public.class, Experiment.View.Write.class}) JsonNode metadata,
@JsonView({
Experiment.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) List<FeedbackScoreAverage> feedbackScores,
@JsonView({Experiment.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Long traceCount,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public Dataset findById(@NonNull UUID id, @NonNull String workspaceId) {
log.info("Finding dataset with id '{}', workspaceId '{}'", id, workspaceId);
return template.inTransaction(READ_ONLY, handle -> {
var dao = handle.attach(DatasetDAO.class);
var dataset = dao.findById(id, workspaceId).orElseThrow(this::newNotFoundException);
var dataset = dao.findById(id, workspaceId).orElseThrow(this::newNotFoundException);
log.info("Found dataset with id '{}', workspaceId '{}'", id, workspaceId);
return dataset;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import com.comet.opik.api.Experiment;
import com.comet.opik.api.ExperimentSearchCriteria;
import com.comet.opik.api.FeedbackScoreAverage;
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.databind.JsonNode;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Result;
Expand Down Expand Up @@ -45,6 +47,7 @@ INSERT INTO experiments (
dataset_id,
name,
workspace_id,
metadata,
created_by,
last_updated_by
)
Expand All @@ -57,6 +60,7 @@ INSERT INTO experiments (
new.dataset_id,
new.name,
new.workspace_id,
new.metadata,
new.created_by,
new.last_updated_by
FROM (
Expand All @@ -65,6 +69,7 @@ INSERT INTO experiments (
:dataset_id AS dataset_id,
:name AS name,
:workspace_id AS workspace_id,
:metadata AS metadata,
:created_by AS created_by,
:last_updated_by AS last_updated_by
) AS new
Expand All @@ -86,6 +91,7 @@ LEFT JOIN (
e.dataset_id as dataset_id,
e.id as id,
e.name as name,
e.metadata as metadata,
e.created_at as created_at,
e.last_updated_at as last_updated_at,
e.created_by as created_by,
Expand Down Expand Up @@ -188,6 +194,7 @@ INNER JOIN (
e.dataset_id,
e.id,
e.name,
e.metadata as metadata,
e.created_at,
e.last_updated_at,
e.created_by,
Expand All @@ -202,6 +209,7 @@ INNER JOIN (
e.dataset_id as dataset_id,
e.id as id,
e.name as name,
e.metadata as metadata,
e.created_at as created_at,
e.last_updated_at as last_updated_at,
e.created_by as created_by,
Expand Down Expand Up @@ -305,6 +313,7 @@ INNER JOIN (
e.dataset_id,
e.id,
e.name,
e.metadata as metadata,
e.created_at,
e.last_updated_at,
e.created_by,
Expand Down Expand Up @@ -351,22 +360,22 @@ private Publisher<? extends Result> insert(Experiment experiment, Connection con
var statement = connection.createStatement(INSERT)
.bind("id", experiment.id())
.bind("dataset_id", experiment.datasetId())
.bind("name", experiment.name());

.bind("name", experiment.name())
.bind("metadata", getOrDefault(experiment.metadata()));
return makeFluxContextAware((userName, workspaceName, workspaceId) -> {

log.info("Inserting experiment with id '{}', datasetId '{}', datasetName '{}', workspaceId '{}'",
experiment.id(), experiment.datasetId(), experiment.datasetName(), workspaceId);

statement
.bind("created_by", userName)
statement.bind("created_by", userName)
.bind("last_updated_by", userName)
.bind("workspace_id", workspaceId);

return Flux.from(statement.execute());
});
}

private String getOrDefault(JsonNode jsonNode) {
    // Serialize the metadata node for storage; a missing (null) node is
    // persisted as the empty string, matching the column's '' default.
    if (jsonNode == null) {
        return "";
    }
    return jsonNode.toString();
}

Mono<Experiment> getById(@NonNull UUID id) {
return Mono.from(connectionFactory.create())
.flatMapMany(connection -> getById(id, connection))
Expand All @@ -379,7 +388,6 @@ private Publisher<? extends Result> getById(UUID id, Connection connection) {
var statement = connection.createStatement(SELECT_BY_ID)
.bind("id", id)
.bind("entity_type", FeedbackScoreDAO.EntityType.TRACE.getType());

return makeFluxContextAware(bindWorkspaceIdToFlux(statement));
}

Expand All @@ -388,6 +396,7 @@ private Publisher<Experiment> mapToDto(Result result) {
.id(row.get("id", UUID.class))
.datasetId(row.get("dataset_id", UUID.class))
.name(row.get("name", String.class))
.metadata(getOrDefault(row.get("metadata", String.class)))
.createdAt(row.get("created_at", Instant.class))
.lastUpdatedAt(row.get("last_updated_at", Instant.class))
.createdBy(row.get("created_by", String.class))
Expand All @@ -397,6 +406,13 @@ private Publisher<Experiment> mapToDto(Result result) {
.build());
}

private JsonNode getOrDefault(String field) {
    // Reverse of the write-side mapping: a null or blank stored value means
    // the experiment has no metadata, so surface it as a null JsonNode
    // instead of parsing an empty string.
    if (field == null || field.isBlank()) {
        return null;
    }
    return JsonUtils.getJsonNodeFromString(field);
}

private static List<FeedbackScoreAverage> getFeedbackScores(Row row) {
List<FeedbackScoreAverage> feedbackScoresAvg = Arrays
.stream(Optional.ofNullable(row.get("feedback_scores", List[].class))
Expand All @@ -406,12 +422,11 @@ private static List<FeedbackScoreAverage> getFeedbackScores(Row row) {
.map(scores -> new FeedbackScoreAverage(scores.getFirst().toString(),
new BigDecimal(scores.get(1).toString())))
.toList();

return feedbackScoresAvg.isEmpty() ? null : feedbackScoresAvg;
}

Mono<Experiment.ExperimentPage> find(int page, int size,
@NonNull ExperimentSearchCriteria experimentSearchCriteria) {
Mono<Experiment.ExperimentPage> find(
int page, int size, @NonNull ExperimentSearchCriteria experimentSearchCriteria) {
return countTotal(experimentSearchCriteria).flatMap(total -> find(page, size, experimentSearchCriteria, total));
}

Expand All @@ -432,7 +447,6 @@ private Publisher<? extends Result> find(
.bind("limit", size)
.bind("offset", (page - 1) * size);
bindSearchCriteria(statement, experimentSearchCriteria, false);

return makeFluxContextAware(bindWorkspaceIdToFlux(statement));
}

Expand All @@ -443,13 +457,12 @@ private Mono<Long> countTotal(ExperimentSearchCriteria experimentSearchCriteria)
.reduce(0L, Long::sum);
}

private Publisher<? extends Result> countTotal(ExperimentSearchCriteria experimentSearchCriteria,
Connection connection) {
private Publisher<? extends Result> countTotal(
ExperimentSearchCriteria experimentSearchCriteria, Connection connection) {
log.info("Counting experiments by '{}'", experimentSearchCriteria);
var template = newFindTemplate(FIND_COUNT, experimentSearchCriteria);
var statement = connection.createStatement(template.render());
bindSearchCriteria(statement, experimentSearchCriteria, true);

return makeFluxContextAware(bindWorkspaceIdToFlux(statement));
}

Expand All @@ -473,11 +486,9 @@ private void bindSearchCriteria(Statement statement, ExperimentSearchCriteria cr
}

public Flux<WorkspaceAndResourceId> getExperimentWorkspaces(@NonNull Set<UUID> experimentIds) {

if (experimentIds.isEmpty()) {
return Flux.empty();
}

return Mono.from(connectionFactory.create())
.flatMapMany(connection -> {
var statement = connection.createStatement(FIND_EXPERIMENT_AND_WORKSPACE_BY_DATASET_IDS);
Expand All @@ -488,5 +499,4 @@ public Flux<WorkspaceAndResourceId> getExperimentWorkspaces(@NonNull Set<UUID> e
row.get("workspace_id", String.class),
row.get("id", UUID.class))));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
--liquibase formatted sql
--changeset andrescrz:add_metadata_to_experiments

-- Add a free-form metadata column to the experiments table, stored as a
-- string that defaults to the empty string (so rows without metadata carry
-- '' at the storage layer rather than NULL).
-- NOTE(review): the `String` type suggests this targets ClickHouse — confirm
-- against the analytics DB configuration.
ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.experiments
ADD COLUMN metadata String DEFAULT '';

--rollback ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.experiments DROP COLUMN metadata;
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ class ExperimentsResourceTest {
private static final String[] EXPERIMENT_IGNORED_FIELDS = new String[]{
"id", "datasetId", "feedbackScores", "traceCount", "createdAt", "lastUpdatedAt", "createdBy",
"lastUpdatedBy"};

public static final String[] IGNORED_FIELDS = {"input", "output", "feedbackScores", "createdAt", "lastUpdatedAt",
"createdBy", "lastUpdatedBy"};

Expand All @@ -102,9 +101,7 @@ class ExperimentsResourceTest {
private static final TimeBasedEpochGenerator GENERATOR = Generators.timeBasedEpochGenerator();

private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer();

private static final MySQLContainer<?> MY_SQL_CONTAINER = MySQLContainerUtils.newMySQLContainer();

private static final ClickHouseContainer CLICK_HOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer();

@RegisterExtension
Expand Down Expand Up @@ -686,6 +683,7 @@ void findByDatasetIdAndName() {
.map(experiment -> experiment.toBuilder()
.datasetName(datasetName)
.name(name)
.metadata(null)
.build())
.toList();
experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment,
Expand Down Expand Up @@ -1248,10 +1246,11 @@ void createAndGetFeedbackAvg() {
}

@Test
void createWithoutIdAndGet() {
void createWithoutOptionalFieldsAndGet() {
var expectedExperiment = podamFactory.manufacturePojo(Experiment.class)
.toBuilder()
.id(null)
.metadata(null)
.build();
var expectedId = createAndAssert(expectedExperiment, API_KEY, TEST_WORKSPACE);

Expand Down

0 comments on commit 5c338ed

Please sign in to comment.