Skip to content

Commit

Permalink
OPIK-71 Resolve dataset name in experiment endpoints (#211)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrescrz authored Sep 11, 2024
1 parent daf1fa0 commit 3163f2f
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
public record Experiment(
@JsonView( {
Experiment.View.Public.class, Experiment.View.Write.class}) UUID id,
@JsonView({Experiment.View.Write.class}) @NotBlank String datasetName,
@JsonView({Experiment.View.Public.class, Experiment.View.Write.class}) @NotBlank String datasetName,
@JsonView({Experiment.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID datasetId,
@JsonView({Experiment.View.Public.class, Experiment.View.Write.class}) @NotBlank String name,
@JsonView({
Expand All @@ -31,7 +31,7 @@ public record Experiment(
@JsonView({
Experiment.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){

@Builder
@Builder(toBuilder = true)
public record ExperimentPage(
@JsonView(Experiment.View.Public.class) int page,
@JsonView(Experiment.View.Public.class) int size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper;
import org.jdbi.v3.sqlobject.customizer.AllowUnusedBindings;
import org.jdbi.v3.sqlobject.customizer.Bind;
import org.jdbi.v3.sqlobject.customizer.BindList;
import org.jdbi.v3.sqlobject.customizer.BindMethods;
import org.jdbi.v3.sqlobject.customizer.Define;
import org.jdbi.v3.sqlobject.statement.SqlQuery;
Expand All @@ -17,6 +18,7 @@

import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;

@RegisterColumnMapper(InstantColumnMapper.class)
Expand All @@ -43,6 +45,9 @@ int update(@Bind("workspace_id") String workspaceId,
@SqlQuery("SELECT * FROM datasets WHERE id = :id AND workspace_id = :workspace_id")
Optional<Dataset> findById(@Bind("id") UUID id, @Bind("workspace_id") String workspaceId);

@SqlQuery("SELECT * FROM datasets WHERE id IN (<ids>) AND workspace_id = :workspace_id")
List<Dataset> findByIds(@BindList("ids") Set<UUID> ids, @Bind("workspace_id") String workspaceId);

@SqlUpdate("DELETE FROM datasets WHERE id = :id AND workspace_id = :workspace_id")
void delete(@Bind("id") UUID id, @Bind("workspace_id") String workspaceId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.sql.SQLIntegrityConstraintViolationException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;

Expand All @@ -45,6 +46,8 @@ public interface DatasetService {

Dataset findById(UUID id, String workspaceId);

List<Dataset> findByIds(Set<UUID> ids, String workspaceId);

Dataset findByName(String workspaceId, String name);

void delete(DatasetIdentifier identifier);
Expand Down Expand Up @@ -139,7 +142,7 @@ public void update(@NonNull UUID id, @NonNull DatasetUpdate dataset) {
int result = dao.update(workspaceId, id, dataset, userName);

if (result == 0) {
throw createNotFoundError();
throw newNotFoundException();
}
} catch (UnableToExecuteStatementException e) {
if (e.getCause() instanceof SQLIntegrityConstraintViolationException) {
Expand Down Expand Up @@ -174,10 +177,27 @@ public Dataset findById(@NonNull UUID id) {

@Override
public Dataset findById(@NonNull UUID id, @NonNull String workspaceId) {
log.info("Finding dataset with id '{}', workspaceId '{}'", id, workspaceId);
return template.inTransaction(READ_ONLY, handle -> {
var dao = handle.attach(DatasetDAO.class);
var dataset = dao.findById(id, workspaceId).orElseThrow(this::newNotFoundException);
log.info("Found dataset with id '{}', workspaceId '{}'", id, workspaceId);
return dataset;
});
}

return dao.findById(id, workspaceId).orElseThrow(this::createNotFoundError);
@Override
public List<Dataset> findByIds(@NonNull Set<UUID> ids, @NonNull String workspaceId) {
if (ids.isEmpty()) {
log.info("Returning empty datasets for empty ids, workspaceId '{}'", workspaceId);
return List.of();
}
log.info("Finding datasets with ids '{}', workspaceId '{}'", ids, workspaceId);
return template.inTransaction(READ_ONLY, handle -> {
var dao = handle.attach(DatasetDAO.class);
var datasets = dao.findByIds(ids, workspaceId);
log.info("Found datasets with ids '{}', workspaceId '{}'", ids, workspaceId);
return datasets;
});
}

Expand All @@ -186,7 +206,7 @@ public Dataset findByName(@NonNull String workspaceId, @NonNull String name) {
return template.inTransaction(READ_ONLY, handle -> {
var dao = handle.attach(DatasetDAO.class);

Dataset dataset = dao.findByName(workspaceId, name).orElseThrow(this::createNotFoundError);
Dataset dataset = dao.findByName(workspaceId, name).orElseThrow(this::newNotFoundException);

log.info("Found dataset with name '{}', id '{}', workspaceId '{}'", name, dataset.id(), workspaceId);
return dataset;
Expand All @@ -204,7 +224,7 @@ public void delete(@NonNull DatasetIdentifier identifier) {
});
}

private NotFoundException createNotFoundError() {
private NotFoundException newNotFoundException() {
String message = "Dataset not found";
return new NotFoundException(message,
Response.status(Response.Status.NOT_FOUND).entity(new ErrorMessage(List.of(message))).build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;

@Singleton
@RequiredArgsConstructor(onConstructor = @__(@Inject))
Expand All @@ -32,12 +35,39 @@ public class ExperimentService {
public Mono<Experiment.ExperimentPage> find(
int page, int size, @NonNull ExperimentSearchCriteria experimentSearchCriteria) {
log.info("Finding experiments by '{}', page '{}', size '{}'", experimentSearchCriteria, page, size);
return experimentDAO.find(page, size, experimentSearchCriteria);
return experimentDAO.find(page, size, experimentSearchCriteria)
.flatMap(experimentPage -> Mono.deferContextual(ctx -> {
String workspaceId = ctx.get(RequestContext.WORKSPACE_ID);
var ids = experimentPage.content().stream()
.map(Experiment::datasetId)
.collect(Collectors.toUnmodifiableSet());
return Mono.fromCallable(() -> datasetService.findByIds(ids, workspaceId))
.subscribeOn(Schedulers.boundedElastic())
.map(datasets -> datasets.stream()
.collect(Collectors.toMap(Dataset::id, Function.identity())))
.map(datasetMap -> experimentPage.toBuilder()
.content(experimentPage.content().stream()
.map(experiment -> experiment.toBuilder()
.datasetName(Optional
.ofNullable(datasetMap.get(experiment.datasetId()))
.map(Dataset::name)
.orElse(null))
.build())
.toList())
.build());
}));
}

public Mono<Experiment> getById(@NonNull UUID id) {
log.info("Getting experiment by id '{}'", id);
return experimentDAO.getById(id).switchIfEmpty(Mono.defer(() -> Mono.error(newNotFoundException(id))));
return experimentDAO.getById(id)
.switchIfEmpty(Mono.defer(() -> Mono.error(newNotFoundException(id))))
.flatMap(experiment -> Mono.deferContextual(ctx -> {
String workspaceId = ctx.get(RequestContext.WORKSPACE_ID);
return Mono.fromCallable(() -> datasetService.findById(experiment.datasetId(), workspaceId))
.subscribeOn(Schedulers.boundedElastic())
.map(dataset -> experiment.toBuilder().datasetName(dataset.name()).build());
}));
}

public Mono<Experiment> create(@NonNull Experiment experiment) {
Expand Down Expand Up @@ -103,5 +133,4 @@ public Mono<Boolean> validateExperimentWorkspace(@NonNull String workspaceId, @N
return experimentDAO.getExperimentWorkspaces(experimentIds)
.all(experimentWorkspace -> workspaceId.equals(experimentWorkspace.workspaceId()));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ExperimentsResourceTest {
private static final String API_KEY = UUID.randomUUID().toString();

private static final String[] EXPERIMENT_IGNORED_FIELDS = new String[]{
"id", "datasetName", "datasetId", "feedbackScores", "traceCount", "createdAt", "lastUpdatedAt", "createdBy",
"id", "datasetId", "feedbackScores", "traceCount", "createdAt", "lastUpdatedAt", "createdBy",
"lastUpdatedBy"};

public static final String[] IGNORED_FIELDS = {"input", "output", "feedbackScores", "createdAt", "lastUpdatedAt",
Expand Down Expand Up @@ -1588,7 +1588,6 @@ private void assertIgnoredFields(
private void assertIgnoredFields(
Experiment actualExperiment, Experiment expectedExperiment, UUID expectedDatasetId) {
assertThat(actualExperiment.id()).isEqualTo(expectedExperiment.id());
assertThat(actualExperiment.datasetName()).isNull();
if (null != expectedDatasetId) {
assertThat(actualExperiment.datasetId()).isEqualTo(expectedDatasetId);
} else {
Expand Down

0 comments on commit 3163f2f

Please sign in to comment.