
OPIK-76 Add stream experiment items endpoint
andrescrz committed Sep 20, 2024
1 parent cacd942 commit 28b5ad0
Showing 7 changed files with 363 additions and 62 deletions.
@@ -0,0 +1,25 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotBlank;
import lombok.Builder;

import java.util.UUID;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record ExperimentItemStreamRequest(
@NotBlank String experimentName,
@Min(1) @Max(2000) Integer limit,
UUID lastRetrievedId) {

@Override
public Integer limit() {
return limit == null ? 500 : limit;
}
}
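
For reference, a minimal usage sketch of the new request record, not part of this commit: it exercises the Lombok-generated builder and the 500 default applied by the overridden limit() accessor; the experiment name is a placeholder.

import com.comet.opik.api.ExperimentItemStreamRequest;

// Hypothetical usage sketch, not part of this commit.
class StreamRequestSketch {

    public static void main(String[] args) {
        // SnakeCaseStrategy maps experimentName -> experiment_name on the wire,
        // and limit() falls back to 500 when the client omits the field.
        var request = ExperimentItemStreamRequest.builder()
                .experimentName("my-experiment") // placeholder name
                .build();

        System.out.println(request.limit());           // 500 (default)
        System.out.println(request.lastRetrievedId()); // null: first page, no cursor yet
    }
}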
@@ -3,6 +3,7 @@
import com.codahale.metrics.annotation.Timed;
import com.comet.opik.api.Experiment;
import com.comet.opik.api.ExperimentItem;
import com.comet.opik.api.ExperimentItemStreamRequest;
import com.comet.opik.api.ExperimentItemsBatch;
import com.comet.opik.api.ExperimentItemsDelete;
import com.comet.opik.api.ExperimentSearchCriteria;
@@ -12,9 +13,11 @@
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.utils.AsyncUtils;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.JsonNode;
import io.dropwizard.jersey.errors.ErrorMessage;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.media.ArraySchema;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
@@ -40,6 +43,7 @@
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;

import java.util.Set;
import java.util.UUID;
@@ -147,6 +151,24 @@ public Response getExperimentItem(@PathParam("id") UUID id) {
return Response.ok().entity(experimentItem).build();
}

@POST
@Path("/items/stream")
@Produces(MediaType.APPLICATION_OCTET_STREAM)
@Operation(operationId = "streamExperimentItems", summary = "Stream experiment items", description = "Stream experiment items", responses = {
@ApiResponse(responseCode = "200", description = "Experiment items stream or error during process", content = @Content(array = @ArraySchema(schema = @Schema(anyOf = {
ExperimentItem.class,
ErrorMessage.class
}), maxItems = 2000)))
})
public ChunkedOutput<JsonNode> streamExperimentItems(
@RequestBody(content = @Content(schema = @Schema(implementation = ExperimentItemStreamRequest.class))) @NotNull @Valid ExperimentItemStreamRequest request) {
var workspaceId = requestContext.get().getWorkspaceId();
log.info("Streaming experiment items by '{}', workspaceId '{}'", request, workspaceId);
var stream = experimentItemService.getExperimentItemsStream(request);
log.info("Streamed dataset items by '{}', workspaceId '{}'", request, workspaceId);
return stream;
}

@POST
@Path("/items")
@Operation(operationId = "createExperimentItems", summary = "Create experiment items", description = "Create experiment items", responses = {
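
A hedged client-side sketch of consuming the new streaming endpoint, not part of this commit: it POSTs the snake_case request body, reads the chunked response line by line, and pages by feeding the last received id back as last_retrieved_id. The base URL and the presence of an "id" field on each streamed item are assumptions; only the /items/stream sub-path, the request fields, and the 2000-per-call cap appear in the diff.

import com.fasterxml.jackson.databind.ObjectMapper;

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.List;

// Hypothetical client-side sketch, not part of this commit.
class ExperimentItemStreamClientSketch {

    private static final ObjectMapper MAPPER = new ObjectMapper();
    private static final String URL = "http://localhost:8080/v1/private/experiments/items/stream"; // assumed base path
    private static final int LIMIT = 500;

    public static void main(String[] args) throws Exception {
        var client = HttpClient.newHttpClient();
        String lastRetrievedId = null;
        List<String> lines;
        do {
            var cursor = lastRetrievedId == null
                    ? ""
                    : ", \"last_retrieved_id\": \"%s\"".formatted(lastRetrievedId);
            var body = "{\"experiment_name\": \"my-experiment\", \"limit\": %d%s}".formatted(LIMIT, cursor);
            var request = HttpRequest.newBuilder(URI.create(URL))
                    .header("Content-Type", "application/json")
                    .POST(HttpRequest.BodyPublishers.ofString(body))
                    .build();
            // Each chunk is newline-delimited, so reading the body as lines yields
            // one JSON object per experiment item (or an error message object).
            lines = client.send(request, HttpResponse.BodyHandlers.ofLines()).body().toList();
            for (var line : lines) {
                System.out.println(line);                                   // process the item
                lastRetrievedId = MAPPER.readTree(line).get("id").asText(); // cursor for the next call (assumes an "id" field)
            }
        } while (lines.size() == LIMIT); // a short page means the experiment is drained
    }
}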
@@ -5,6 +5,7 @@
import com.comet.opik.api.FeedbackScoreAverage;
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Preconditions;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Result;
@@ -16,6 +17,7 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.reactivestreams.Publisher;
import org.stringtemplate.v4.ST;
import reactor.core.publisher.Flux;
@@ -338,6 +340,18 @@ SELECT count(id) as count
;
""";

private static final String FIND_BY_NAME = """
SELECT
*,
null AS feedback_scores,
null AS trace_count
FROM experiments
WHERE workspace_id = :workspace_id
AND ilike(name, CONCAT('%', :name, '%'))
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
""";

private static final String FIND_EXPERIMENT_AND_WORKSPACE_BY_DATASET_IDS = """
SELECT
id, workspace_id
@@ -485,6 +499,20 @@ private void bindSearchCriteria(Statement statement, ExperimentSearchCriteria cr
}
}

Flux<Experiment> findByName(String name) {
Preconditions.checkArgument(StringUtils.isNotBlank(name), "Argument 'name' must not be blank");
return Mono.from(connectionFactory.create())
.flatMapMany(connection -> findByName(name, connection))
.flatMap(this::mapToDto);
}

private Publisher<? extends Result> findByName(String name, Connection connection) {
log.info("Finding experiment by name '{}'", name);
var statement = connection.createStatement(FIND_BY_NAME)
.bind("name", name);
return makeFluxContextAware(bindWorkspaceIdToFlux(statement));
}

public Flux<WorkspaceAndResourceId> getExperimentWorkspaces(@NonNull Set<UUID> experimentIds) {
if (experimentIds.isEmpty()) {
return Flux.empty();
@@ -83,6 +83,19 @@ INSERT INTO experiment_items (
;
""";

private static final String STREAM = """
SELECT
*
FROM experiment_items
WHERE workspace_id = :workspace_id
AND experiment_id IN :experiment_ids
<if(lastRetrievedId)> AND id \\< :lastRetrievedId <endif>
ORDER BY experiment_id DESC, id DESC, last_updated_at DESC
LIMIT 1 BY id
LIMIT :limit
;
""";

private static final String DELETE = """
DELETE FROM experiment_items
WHERE id IN :ids
@@ -201,6 +214,29 @@ private Publisher<? extends Result> get(UUID id, Connection connection) {
return makeFluxContextAware(bindWorkspaceIdToFlux(statement));
}

public Flux<ExperimentItem> getItems(Set<UUID> experimentIds, int limit, UUID lastRetrievedId) {
return Mono.from(connectionFactory.create())
.flatMapMany(connection -> getItems(experimentIds, limit, lastRetrievedId, connection))
.flatMap(this::mapToExperimentItem);
}

private Publisher<? extends Result> getItems(
Set<UUID> experimentIds, int limit, UUID lastRetrievedId, Connection connection) {
log.info("Streaming experiment items by experimentIds count '{}', limit '{}', lastRetrievedId '{}'",
experimentIds.size(), limit, lastRetrievedId);
var template = new ST(STREAM);
if (lastRetrievedId != null) {
template.add("lastRetrievedId", lastRetrievedId);
}
var statement = connection.createStatement(template.render())
.bind("experiment_ids", experimentIds.toArray(UUID[]::new))
.bind("limit", limit);
if (lastRetrievedId != null) {
statement.bind("lastRetrievedId", lastRetrievedId);
}
return makeFluxContextAware(bindWorkspaceIdToFlux(statement));
}

public Mono<Long> delete(Set<UUID> ids) {
Preconditions.checkArgument(CollectionUtils.isNotEmpty(ids),
"Argument 'ids' must not be empty");
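
A minimal sketch of the StringTemplate conditional used by the STREAM query above, not part of this commit. It illustrates why getItems guards lastRetrievedId twice: the <if(lastRetrievedId)> block only renders the keyset predicate when the attribute is added, and :lastRetrievedId can only be bound when that placeholder exists in the rendered SQL. The SQL below is an abridged copy of the STREAM constant.

import org.stringtemplate.v4.ST;

import java.util.UUID;

// Hypothetical sketch, not part of this commit. '<' is escaped as "\\<" in the
// template because '<' is StringTemplate's delimiter.
class StreamTemplateSketch {

    private static final String STREAM = """
            SELECT * FROM experiment_items
            WHERE workspace_id = :workspace_id
            AND experiment_id IN :experiment_ids
            <if(lastRetrievedId)> AND id \\< :lastRetrievedId <endif>
            ORDER BY experiment_id DESC, id DESC, last_updated_at DESC
            LIMIT :limit
            """;

    public static void main(String[] args) {
        var firstPage = new ST(STREAM);          // no cursor yet: predicate is omitted
        System.out.println(firstPage.render());

        var nextPage = new ST(STREAM);
        nextPage.add("lastRetrievedId", UUID.randomUUID());
        System.out.println(nextPage.render());   // contains "AND id < :lastRetrievedId"
    }
}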
@@ -1,9 +1,15 @@
package com.comet.opik.domain;

import com.comet.opik.api.Experiment;
import com.comet.opik.api.ExperimentItem;
import com.comet.opik.api.ExperimentItemStreamRequest;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Preconditions;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
import jakarta.inject.Provider;
import jakarta.inject.Singleton;
import jakarta.ws.rs.ClientErrorException;
import jakarta.ws.rs.NotFoundException;
@@ -12,10 +18,16 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.glassfish.jersey.server.ChunkedOutput;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

@Singleton
Expand All @@ -26,6 +38,7 @@ public class ExperimentItemService {
private final @NonNull ExperimentItemDAO experimentItemDAO;
private final @NonNull ExperimentService experimentService;
private final @NonNull DatasetItemDAO datasetItemDAO;
private final @NonNull Provider<RequestContext> requestContext;

public Mono<Void> create(Set<ExperimentItem> experimentItems) {
Preconditions.checkArgument(CollectionUtils.isNotEmpty(experimentItems),
@@ -116,12 +129,63 @@ private NotFoundException newNotFoundException(UUID id) {
return new NotFoundException(message);
}

public ChunkedOutput<JsonNode> getExperimentItemsStream(@NonNull ExperimentItemStreamRequest request) {
var outputStream = new ChunkedOutput<JsonNode>(JsonNode.class, "\r\n");
var workspaceId = requestContext.get().getWorkspaceId();
var userName = requestContext.get().getUserName();
var workspaceName = requestContext.get().getWorkspaceName();
log.info("Getting experiment items stream by '{}', workspaceId '{}'", request, workspaceId);
Schedulers.boundedElastic()
.schedule(() -> Mono
.fromCallable(() -> experimentService.findByName(request.experimentName()))
.subscribeOn(Schedulers.boundedElastic())
.flatMap(experiments -> experiments.map(Experiment::id).collect(Collectors.toUnmodifiableSet()))
.flatMapMany(
experimentIds -> experimentItemDAO.getItems(
experimentIds, request.limit(), request.lastRetrievedId()))
.doOnNext(item -> sendItem(item, outputStream))
.onErrorResume(throwable -> handleError(throwable, outputStream))
.doFinally(signalType -> close(outputStream))
.contextWrite(ctx -> ctx.put(RequestContext.USER_NAME, userName)
.put(RequestContext.WORKSPACE_NAME, workspaceName)
.put(RequestContext.WORKSPACE_ID, workspaceId))
.subscribe());
log.info("Got experiment items stream by '{}', workspaceId '{}'", request, workspaceId);
return outputStream;
}

private void sendItem(ExperimentItem item, ChunkedOutput<JsonNode> outputStream) {
try {
outputStream.write(JsonUtils.readTree(item));
} catch (IOException exception) {
throw new UncheckedIOException(exception);
}
}

private Flux<ExperimentItem> handleError(Throwable throwable, ChunkedOutput<JsonNode> outputStream) {
if (throwable instanceof TimeoutException) {
try {
outputStream.write(JsonUtils.readTree(new ErrorMessage(500, "Streaming operation timed out")));
} catch (IOException ioException) {
log.error("Failed to stream error to client", ioException);
}
}
return Flux.error(throwable);
}

private void close(ChunkedOutput<JsonNode> outputStream) {
try {
outputStream.close();
} catch (IOException exception) {
log.error("Error while closing experiment items stream", exception);
}
}

public Mono<Void> delete(@NonNull Set<UUID> ids) {
Preconditions.checkArgument(CollectionUtils.isNotEmpty(ids),
"Argument 'ids' must not be empty");

log.info("Deleting experiment items, count '{}'", ids.size());
return experimentItemDAO.delete(ids).then();
}

}
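
A hedged consumer-side sketch, not part of this commit: handleError writes a Dropwizard ErrorMessage into the same stream that normally carries experiment items, so a client can tell the two apart by shape. The code/message field names are assumptions based on io.dropwizard.jersey.errors.ErrorMessage.

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

// Hypothetical consumer-side check, not part of this commit.
class StreamedLineSketch {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    static void handle(String line) throws Exception {
        JsonNode node = MAPPER.readTree(line);
        if (node.has("code") && node.has("message")) {
            // e.g. {"code":500,"message":"Streaming operation timed out"} — assumed shape
            throw new IllegalStateException("Stream aborted: " + node.get("message").asText());
        }
        System.out.println(node); // otherwise treat it as an experiment item
    }
}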
@@ -6,6 +6,7 @@
import com.comet.opik.api.ExperimentSearchCriteria;
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.google.common.base.Preconditions;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.ws.rs.ClientErrorException;
@@ -15,6 +16,7 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

@@ -60,6 +62,12 @@ public Mono<Experiment.ExperimentPage> find(
}));
}

public Flux<Experiment> findByName(String name) {
Preconditions.checkArgument(StringUtils.isNotBlank(name), "Argument 'name' must not be blank");
log.info("Finding experiments by name '{}'", name);
return experimentDAO.findByName(name);
}

public Mono<Experiment> getById(@NonNull UUID id) {
log.info("Getting experiment by id '{}'", id);
return experimentDAO.getById(id)
(1 more changed file not shown)
