Merge branch 'main' into thiagohora/fix_config_file
thiagohora authored Sep 20, 2024
2 parents 896706e + 90f5f49 commit 3a7faec
Showing 23 changed files with 574 additions and 181 deletions.
@@ -0,0 +1,25 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.constraints.Max;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotBlank;
import lombok.Builder;

import java.util.UUID;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record ExperimentItemStreamRequest(
@NotBlank String experimentName,
@Min(1) @Max(2000) Integer limit,
UUID lastRetrievedId) {

@Override
public Integer limit() {
return limit == null ? 500 : limit;
}
}
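
For illustration only (not part of this commit): a minimal sketch of how a caller might build and serialize the request record above, assuming the Lombok-generated builder and a plain Jackson ObjectMapper; the experiment name is a made-up placeholder.

import com.comet.opik.api.ExperimentItemStreamRequest;
import com.fasterxml.jackson.databind.ObjectMapper;

class ExperimentItemStreamRequestSketch {

    public static void main(String[] args) throws Exception {
        // First page: only the experiment name is set; limit and lastRetrievedId stay null.
        var request = ExperimentItemStreamRequest.builder()
                .experimentName("my-experiment")
                .build();

        // The overridden accessor falls back to 500 when no explicit limit was provided.
        System.out.println(request.limit()); // 500

        // SnakeCaseStrategy maps experimentName -> experiment_name and lastRetrievedId -> last_retrieved_id.
        System.out.println(new ObjectMapper().writeValueAsString(request));
    }
}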
@@ -15,6 +15,7 @@
import com.comet.opik.domain.DatasetService;
import com.comet.opik.domain.FeedbackScoreDAO;
import com.comet.opik.domain.IdGenerator;
import com.comet.opik.domain.Streamer;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.ratelimit.RateLimited;
import com.comet.opik.utils.AsyncUtils;
@@ -56,17 +57,11 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

import static com.comet.opik.api.Dataset.DatasetPage;
@@ -81,15 +76,14 @@
@Tag(name = "Datasets", description = "Dataset resources")
public class DatasetsResource {

private static final String STREAM_ERROR_LOG = "Error while streaming dataset items";

private static final TypeReference<List<UUID>> LIST_UUID_TYPE_REFERENCE = new TypeReference<>() {
};

private final @NonNull DatasetService service;
private final @NonNull DatasetItemService itemService;
private final @NonNull Provider<RequestContext> requestContext;
private final @NonNull IdGenerator idGenerator;
private final @NonNull Streamer streamer;

@GET
@Path("/{id}")
@@ -271,79 +265,23 @@ public Response getDatasetItems(
@ApiResponse(responseCode = "200", description = "Dataset items stream or error during process", content = @Content(array = @ArraySchema(schema = @Schema(anyOf = {
DatasetItem.class,
ErrorMessage.class
}), maxItems = 1000)))
}), maxItems = 2000)))
})
public ChunkedOutput<JsonNode> streamDatasetItems(
@RequestBody(content = @Content(schema = @Schema(implementation = DatasetItemStreamRequest.class))) @NotNull @Valid DatasetItemStreamRequest request) {

String workspaceId = requestContext.get().getWorkspaceId();
var workspaceId = requestContext.get().getWorkspaceId();
var userName = requestContext.get().getUserName();
var workspaceName = requestContext.get().getWorkspaceName();
log.info("Streaming dataset items by '{}' on workspaceId '{}'", request, workspaceId);
return getOutputStream(request, request.steamLimit());
}

private ChunkedOutput<JsonNode> getOutputStream(DatasetItemStreamRequest request, int limit) {

ChunkedOutput<JsonNode> outputStream = new ChunkedOutput<>(JsonNode.class, "\r\n");

String workspaceId = requestContext.get().getWorkspaceId();
String userName = requestContext.get().getUserName();
String workspaceName = requestContext.get().getWorkspaceName();

Schedulers
.boundedElastic()
.schedule(() -> {
Mono.fromCallable(() -> service.findByName(workspaceId, request.datasetName()))
.subscribeOn(Schedulers.boundedElastic())
.flatMapMany(
dataset -> itemService.getItems(dataset.id(), limit, request.lastRetrievedId()))
.doOnNext(item -> sendDatasetItems(item, outputStream))
.onErrorResume(ex -> errorHandling(ex, outputStream))
.doFinally(signalType -> closeOutput(outputStream))
.contextWrite(ctx -> ctx.put(RequestContext.USER_NAME, userName)
.put(RequestContext.WORKSPACE_NAME, workspaceName)
.put(RequestContext.WORKSPACE_ID, workspaceId))
.subscribe();

log.info("Streamed dataset items by '{}' on workspaceId '{}'", request, workspaceId);
});

var items = itemService.getItems(workspaceId, request)
.contextWrite(ctx -> ctx.put(RequestContext.USER_NAME, userName)
.put(RequestContext.WORKSPACE_NAME, workspaceName)
.put(RequestContext.WORKSPACE_ID, workspaceId));
var outputStream = streamer.getOutputStream(items);
log.info("Streamed dataset items by '{}' on workspaceId '{}'", request, workspaceId);
return outputStream;
}

private void closeOutput(ChunkedOutput<JsonNode> outputStream) {
try {
outputStream.close();
} catch (IOException e) {
log.error(STREAM_ERROR_LOG, e);
}
}

private <T> Flux<T> errorHandling(Throwable ex, ChunkedOutput<JsonNode> outputStream) {
if (ex instanceof TimeoutException timeoutException) {
try {
writeError(outputStream, "Streaming operation timed out");
} catch (IOException ioe) {
log.warn("Failed to send error to client", ioe);
}

return Flux.error(timeoutException);
}

return Flux.error(ex);
}

private void writeError(ChunkedOutput<JsonNode> outputStream, String errorMessage) throws IOException {
outputStream.write(JsonUtils.readTree(new ErrorMessage(500, errorMessage)));
}

private void sendDatasetItems(DatasetItem item, ChunkedOutput<JsonNode> writer) {
try {
writer.write(JsonUtils.readTree(item));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}
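
The helpers removed above (closeOutput, errorHandling, writeError, sendDatasetItems) are now the responsibility of the injected Streamer, whose source file is not shown in this commit. As a hedged sketch only, a component that consolidates the deleted logic could look roughly like this; everything beyond the class name and the getOutputStream call used by the resources is an assumption.

package com.comet.opik.domain;

// Hypothetical reconstruction for illustration; the real Streamer is not part of this diff.
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.databind.JsonNode;
import io.dropwizard.jersey.errors.ErrorMessage;
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.concurrent.TimeoutException;

@Slf4j
public class Streamer {

    public <T> ChunkedOutput<JsonNode> getOutputStream(Flux<T> items) {
        ChunkedOutput<JsonNode> outputStream = new ChunkedOutput<>(JsonNode.class, "\r\n");
        items.subscribeOn(Schedulers.boundedElastic())
                .doOnNext(item -> writeItem(item, outputStream))
                .onErrorResume(ex -> handleError(ex, outputStream))
                .doFinally(signalType -> closeOutput(outputStream))
                .subscribe();
        return outputStream;
    }

    private <T> void writeItem(T item, ChunkedOutput<JsonNode> outputStream) {
        try {
            outputStream.write(JsonUtils.readTree(item));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    private <T> Flux<T> handleError(Throwable ex, ChunkedOutput<JsonNode> outputStream) {
        if (ex instanceof TimeoutException) {
            try {
                writeItem(new ErrorMessage(500, "Streaming operation timed out"), outputStream);
            } catch (UncheckedIOException e) {
                log.warn("Failed to send error to client", e);
            }
        }
        return Flux.error(ex);
    }

    private void closeOutput(ChunkedOutput<JsonNode> outputStream) {
        try {
            outputStream.close();
        } catch (IOException e) {
            log.error("Error while closing items stream", e);
        }
    }
}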

@PUT
@Path("/items")
@Operation(operationId = "createOrUpdateDatasetItems", summary = "Create/update dataset items", description = "Create/update dataset items based on dataset item id", responses = {
@@ -3,19 +3,23 @@
import com.codahale.metrics.annotation.Timed;
import com.comet.opik.api.Experiment;
import com.comet.opik.api.ExperimentItem;
import com.comet.opik.api.ExperimentItemStreamRequest;
import com.comet.opik.api.ExperimentItemsBatch;
import com.comet.opik.api.ExperimentItemsDelete;
import com.comet.opik.api.ExperimentSearchCriteria;
import com.comet.opik.domain.ExperimentItemService;
import com.comet.opik.domain.ExperimentService;
import com.comet.opik.domain.IdGenerator;
import com.comet.opik.domain.Streamer;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.ratelimit.RateLimited;
import com.comet.opik.utils.AsyncUtils;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.JsonNode;
import io.dropwizard.jersey.errors.ErrorMessage;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.media.ArraySchema;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
@@ -41,6 +45,7 @@
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;

import java.util.Set;
import java.util.UUID;
@@ -62,6 +67,7 @@ public class ExperimentsResource {
private final @NonNull ExperimentItemService experimentItemService;
private final @NonNull Provider<RequestContext> requestContext;
private final @NonNull IdGenerator idGenerator;
private final @NonNull Streamer streamer;

@GET
@Operation(operationId = "findExperiments", summary = "Find experiments", description = "Find experiments", responses = {
@@ -149,6 +155,30 @@ public Response getExperimentItem(@PathParam("id") UUID id) {
return Response.ok().entity(experimentItem).build();
}

@POST
@Path("/items/stream")
@Produces(MediaType.APPLICATION_OCTET_STREAM)
@Operation(operationId = "streamExperimentItems", summary = "Stream experiment items", description = "Stream experiment items", responses = {
@ApiResponse(responseCode = "200", description = "Experiment items stream or error during process", content = @Content(array = @ArraySchema(schema = @Schema(anyOf = {
ExperimentItem.class,
ErrorMessage.class
}), maxItems = 2000)))
})
public ChunkedOutput<JsonNode> streamExperimentItems(
@RequestBody(content = @Content(schema = @Schema(implementation = ExperimentItemStreamRequest.class))) @NotNull @Valid ExperimentItemStreamRequest request) {
var workspaceId = requestContext.get().getWorkspaceId();
var userName = requestContext.get().getUserName();
var workspaceName = requestContext.get().getWorkspaceName();
log.info("Streaming experiment items by '{}', workspaceId '{}'", request, workspaceId);
var items = experimentItemService.getExperimentItems(request)
.contextWrite(ctx -> ctx.put(RequestContext.USER_NAME, userName)
.put(RequestContext.WORKSPACE_NAME, workspaceName)
.put(RequestContext.WORKSPACE_ID, workspaceId));
var stream = streamer.getOutputStream(items);
log.info("Streamed experiment items by '{}', workspaceId '{}'", request, workspaceId);
return stream;
}
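
For illustration only (not part of this commit): a client-side sketch of calling the new endpoint and paging with last_retrieved_id. The base URL is a placeholder, authentication headers are omitted, and newline-delimited JSON framing is assumed (the dataset-item stream previously used a "\r\n" chunk delimiter); none of these details are guaranteed by the diff itself.

import com.fasterxml.jackson.databind.ObjectMapper;

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

class ExperimentItemStreamClientSketch {

    public static void main(String[] args) throws Exception {
        var client = HttpClient.newHttpClient();
        var mapper = new ObjectMapper();

        // Placeholder URL: adjust host and resource prefix to the actual deployment.
        var uri = URI.create("http://localhost:8080/v1/private/experiments/items/stream");

        // Ask for up to 2000 items of the named experiment; omit last_retrieved_id on the first call.
        var body = "{\"experiment_name\": \"my-experiment\", \"limit\": 2000}";
        var request = HttpRequest.newBuilder(uri)
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();

        // Assuming line-delimited JSON chunks: each line is either an experiment item or an ErrorMessage.
        var response = client.send(request, HttpResponse.BodyHandlers.ofLines());
        response.body()
                .filter(line -> !line.isBlank())
                .forEach(line -> {
                    try {
                        // To fetch the next page, re-send the request with "last_retrieved_id"
                        // set to the id of the last item printed here.
                        System.out.println(mapper.readTree(line));
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                });
    }
}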

@POST
@Path("/items")
@Operation(operationId = "createExperimentItems", summary = "Create experiment items", description = "Create experiment items", responses = {
@@ -499,6 +499,9 @@ public Mono<DatasetItem> get(@NonNull UUID id) {
@Override
@Trace(dispatcher = true)
public Flux<DatasetItem> getItems(@NonNull UUID datasetId, int limit, UUID lastRetrievedId) {
log.info("Getting dataset items by datasetId '{}', limit '{}', lastRetrievedId '{}'",
datasetId, limit, lastRetrievedId);

ST template = new ST(SELECT_DATASET_ITEMS_STREAM);

if (lastRetrievedId != null) {
@@ -4,6 +4,7 @@
import com.comet.opik.api.DatasetItem;
import com.comet.opik.api.DatasetItemBatch;
import com.comet.opik.api.DatasetItemSearchCriteria;
import com.comet.opik.api.DatasetItemStreamRequest;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.api.error.IdentifierMismatchException;
import com.comet.opik.infrastructure.auth.RequestContext;
@@ -42,8 +43,7 @@ public interface DatasetItemService {

Mono<DatasetItemPage> getItems(int page, int size, DatasetItemSearchCriteria datasetItemSearchCriteria);

Flux<DatasetItem> getItems(UUID datasetId, int limit, UUID lastRetrievedId);

Flux<DatasetItem> getItems(String workspaceId, DatasetItemStreamRequest request);
}

@Singleton
@@ -110,8 +110,11 @@ public Mono<DatasetItem> get(@NonNull UUID id) {

@Override
@Trace(dispatcher = true)
public Flux<DatasetItem> getItems(@NonNull UUID datasetId, int limit, UUID lastRetrievedId) {
return dao.getItems(datasetId, limit, lastRetrievedId);
public Flux<DatasetItem> getItems(@NonNull String workspaceId, @NonNull DatasetItemStreamRequest request) {
log.info("Getting dataset items by '{}' on workspaceId '{}'", request, workspaceId);
return Mono.fromCallable(() -> datasetService.findByName(workspaceId, request.datasetName()))
.subscribeOn(Schedulers.boundedElastic())
.flatMapMany(dataset -> dao.getItems(dataset.id(), request.steamLimit(), request.lastRetrievedId()));
}

private Mono<Long> saveBatch(DatasetItemBatch batch, UUID id) {
@@ -5,6 +5,7 @@
import com.comet.opik.api.FeedbackScoreAverage;
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Preconditions;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Result;
@@ -16,6 +17,7 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.reactivestreams.Publisher;
import org.stringtemplate.v4.ST;
import reactor.core.publisher.Flux;
@@ -338,6 +340,18 @@ SELECT count(id) as count
;
""";

private static final String FIND_BY_NAME = """
SELECT
*,
null AS feedback_scores,
null AS trace_count
FROM experiments
WHERE workspace_id = :workspace_id
AND ilike(name, CONCAT('%', :name, '%'))
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
""";

private static final String FIND_EXPERIMENT_AND_WORKSPACE_BY_DATASET_IDS = """
SELECT
id, workspace_id
@@ -485,6 +499,19 @@ private void bindSearchCriteria(Statement statement, ExperimentSearchCriteria cr
}
}

Flux<Experiment> findByName(String name) {
Preconditions.checkArgument(StringUtils.isNotBlank(name), "Argument 'name' must not be blank");
return Mono.from(connectionFactory.create())
.flatMapMany(connection -> findByName(name, connection))
.flatMap(this::mapToDto);
}

private Publisher<? extends Result> findByName(String name, Connection connection) {
log.info("Finding experiment by name '{}'", name);
var statement = connection.createStatement(FIND_BY_NAME).bind("name", name);
return makeFluxContextAware(bindWorkspaceIdToFlux(statement));
}

public Flux<WorkspaceAndResourceId> getExperimentWorkspaces(@NonNull Set<UUID> experimentIds) {
if (experimentIds.isEmpty()) {
return Flux.empty();
@@ -83,6 +83,19 @@ INSERT INTO experiment_items (
;
""";

private static final String STREAM = """
SELECT
*
FROM experiment_items
WHERE workspace_id = :workspace_id
AND experiment_id IN :experiment_ids
<if(lastRetrievedId)> AND id \\< :lastRetrievedId <endif>
ORDER BY experiment_id DESC, id DESC, last_updated_at DESC
LIMIT 1 BY id
LIMIT :limit
;
""";

private static final String DELETE = """
DELETE FROM experiment_items
WHERE id IN :ids
@@ -201,6 +214,34 @@ private Publisher<? extends Result> get(UUID id, Connection connection) {
return makeFluxContextAware(bindWorkspaceIdToFlux(statement));
}

public Flux<ExperimentItem> getItems(@NonNull Set<UUID> experimentIds, int limit, UUID lastRetrievedId) {
if (experimentIds.isEmpty()) {
log.info("Getting experiment items by empty experimentIds, limit '{}', lastRetrievedId '{}'",
limit, lastRetrievedId);
return Flux.empty();
}
return Mono.from(connectionFactory.create())
.flatMapMany(connection -> getItems(experimentIds, limit, lastRetrievedId, connection))
.flatMap(this::mapToExperimentItem);
}

private Publisher<? extends Result> getItems(
Set<UUID> experimentIds, int limit, UUID lastRetrievedId, Connection connection) {
log.info("Getting experiment items by experimentIds count '{}', limit '{}', lastRetrievedId '{}'",
experimentIds.size(), limit, lastRetrievedId);
var template = new ST(STREAM);
if (lastRetrievedId != null) {
template.add("lastRetrievedId", lastRetrievedId);
}
var statement = connection.createStatement(template.render())
.bind("experiment_ids", experimentIds.toArray(UUID[]::new))
.bind("limit", limit);
if (lastRetrievedId != null) {
statement.bind("lastRetrievedId", lastRetrievedId);
}
return makeFluxContextAware(bindWorkspaceIdToFlux(statement));
}

public Mono<Long> delete(Set<UUID> ids) {
Preconditions.checkArgument(CollectionUtils.isNotEmpty(ids),
"Argument 'ids' must not be empty");
