diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemStreamRequest.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemStreamRequest.java new file mode 100644 index 0000000000..97fc11bcc2 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemStreamRequest.java @@ -0,0 +1,32 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotBlank; +import lombok.Builder; + +import java.util.UUID; + +/** + * Request body for streaming experiment items. + * + * @param experimentName name used to look up the experiments whose items are streamed; must not be blank + * @param limit maximum number of items to stream, between 1 and 2000; {@link #limit()} returns 500 when it is null + * @param lastRetrievedId id of the last item already received, used as a cursor: only items with a smaller id are selected + */ +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record ExperimentItemStreamRequest( + @NotBlank String experimentName, + @Min(1) @Max(2000) Integer limit, + UUID lastRetrievedId) { + + @Override + public Integer limit() { + return limit == null ? 
500 : limit; + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java index 5f80960d18..757310c117 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java @@ -15,6 +15,7 @@ import com.comet.opik.domain.DatasetService; import com.comet.opik.domain.FeedbackScoreDAO; import com.comet.opik.domain.IdGenerator; +import com.comet.opik.domain.Streamer; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.comet.opik.utils.AsyncUtils; @@ -56,17 +57,11 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.glassfish.jersey.server.ChunkedOutput; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; -import java.io.IOException; -import java.io.UncheckedIOException; import java.net.URI; import java.util.List; import java.util.Set; import java.util.UUID; -import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; import static com.comet.opik.api.Dataset.DatasetPage; @@ -81,8 +76,6 @@ @Tag(name = "Datasets", description = "Dataset resources") public class DatasetsResource { - private static final String STREAM_ERROR_LOG = "Error while streaming dataset items"; - private static final TypeReference<List<UUID>> LIST_UUID_TYPE_REFERENCE = new TypeReference<>() { }; @@ -90,6 +83,7 @@ public class DatasetsResource { private final @NonNull DatasetItemService itemService; private final @NonNull Provider<RequestContext> requestContext; private final @NonNull IdGenerator idGenerator; + private final @NonNull Streamer streamer; @GET @Path("/{id}") @@ -271,79 +265,23 @@ public Response getDatasetItems( @ApiResponse(responseCode = "200", 
description = "Dataset items stream or error during process", content = @Content(array = @ArraySchema(schema = @Schema(anyOf = { DatasetItem.class, ErrorMessage.class - }), maxItems = 1000))) + }), maxItems = 2000))) }) public ChunkedOutput<JsonNode> streamDatasetItems( @RequestBody(content = @Content(schema = @Schema(implementation = DatasetItemStreamRequest.class))) @NotNull @Valid DatasetItemStreamRequest request) { - - String workspaceId = requestContext.get().getWorkspaceId(); + var workspaceId = requestContext.get().getWorkspaceId(); + var userName = requestContext.get().getUserName(); + var workspaceName = requestContext.get().getWorkspaceName(); log.info("Streaming dataset items by '{}' on workspaceId '{}'", request, workspaceId); - return getOutputStream(request, request.steamLimit()); - } - - private ChunkedOutput<JsonNode> getOutputStream(DatasetItemStreamRequest request, int limit) { - - ChunkedOutput<JsonNode> outputStream = new ChunkedOutput<>(JsonNode.class, "\r\n"); - - String workspaceId = requestContext.get().getWorkspaceId(); - String userName = requestContext.get().getUserName(); - String workspaceName = requestContext.get().getWorkspaceName(); - - Schedulers - .boundedElastic() - .schedule(() -> { - Mono.fromCallable(() -> service.findByName(workspaceId, request.datasetName())) - .subscribeOn(Schedulers.boundedElastic()) - .flatMapMany( - dataset -> itemService.getItems(dataset.id(), limit, request.lastRetrievedId())) - .doOnNext(item -> sendDatasetItems(item, outputStream)) - .onErrorResume(ex -> errorHandling(ex, outputStream)) - .doFinally(signalType -> closeOutput(outputStream)) - .contextWrite(ctx -> ctx.put(RequestContext.USER_NAME, userName) - .put(RequestContext.WORKSPACE_NAME, workspaceName) - .put(RequestContext.WORKSPACE_ID, workspaceId)) - .subscribe(); - - log.info("Streamed dataset items by '{}' on workspaceId '{}'", request, workspaceId); - }); - + var items = itemService.getItems(workspaceId, request) + .contextWrite(ctx -> ctx.put(RequestContext.USER_NAME, 
userName) + .put(RequestContext.WORKSPACE_NAME, workspaceName) + .put(RequestContext.WORKSPACE_ID, workspaceId)); + var outputStream = streamer.getOutputStream(items); /* returns at once; the items are written asynchronously, so the log below marks hand-off only */ + log.info("Streamed dataset items by '{}' on workspaceId '{}'", request, workspaceId); return outputStream; } - private void closeOutput(ChunkedOutput<JsonNode> outputStream) { - try { - outputStream.close(); - } catch (IOException e) { - log.error(STREAM_ERROR_LOG, e); - } - } - - private Flux<DatasetItem> errorHandling(Throwable ex, ChunkedOutput<JsonNode> outputStream) { - if (ex instanceof TimeoutException timeoutException) { - try { - writeError(outputStream, "Streaming operation timed out"); - } catch (IOException ioe) { - log.warn("Failed to send error to client", ioe); - } - - return Flux.error(timeoutException); - } - - return Flux.error(ex); - } - - private void writeError(ChunkedOutput<JsonNode> outputStream, String errorMessage) throws IOException { - outputStream.write(JsonUtils.readTree(new ErrorMessage(500, errorMessage))); - } - - private void sendDatasetItems(DatasetItem item, ChunkedOutput<JsonNode> writer) { - try { - writer.write(JsonUtils.readTree(item)); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } - @PUT @Path("/items") @Operation(operationId = "createOrUpdateDatasetItems", summary = "Create/update dataset items", description = "Create/update dataset items based on dataset item id", responses = { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java index 31caf2a3cb..9ce949c97e 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java @@ -3,19 +3,23 @@ import com.codahale.metrics.annotation.Timed; import com.comet.opik.api.Experiment; import com.comet.opik.api.ExperimentItem; +import 
com.comet.opik.api.ExperimentItemStreamRequest; import com.comet.opik.api.ExperimentItemsBatch; import com.comet.opik.api.ExperimentItemsDelete; import com.comet.opik.api.ExperimentSearchCriteria; import com.comet.opik.domain.ExperimentItemService; import com.comet.opik.domain.ExperimentService; import com.comet.opik.domain.IdGenerator; +import com.comet.opik.domain.Streamer; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.ratelimit.RateLimited; import com.comet.opik.utils.AsyncUtils; import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.JsonNode; import io.dropwizard.jersey.errors.ErrorMessage; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.headers.Header; +import io.swagger.v3.oas.annotations.media.ArraySchema; import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; @@ -41,6 +45,7 @@ import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.glassfish.jersey.server.ChunkedOutput; import java.util.Set; import java.util.UUID; @@ -62,6 +67,7 @@ public class ExperimentsResource { private final @NonNull ExperimentItemService experimentItemService; private final @NonNull Provider<RequestContext> requestContext; private final @NonNull IdGenerator idGenerator; + private final @NonNull Streamer streamer; @GET @Operation(operationId = "findExperiments", summary = "Find experiments", description = "Find experiments", responses = { @@ -149,6 +155,38 @@ public Response getExperimentItem(@PathParam("id") UUID id) { return Response.ok().entity(experimentItem).build(); } + /** + * Streams the experiment items selected by the request as JSON chunks. + * The request-context workspace id, workspace name and user name are written into the reactive context. + * This method returns once streaming has been handed to the streamer, so the final log line marks hand-off only. + * + * @param request experiment name plus optional limit and last retrieved id + * @return the chunked output the items are written to + */ + @POST + @Path("/items/stream") + @Produces(MediaType.APPLICATION_OCTET_STREAM) + @Operation(operationId = "streamExperimentItems", summary = "Stream experiment items", description = "Stream experiment items", responses = { + @ApiResponse(responseCode = 
"200", description = "Experiment items stream or error during process", content = @Content(array = @ArraySchema(schema = @Schema(anyOf = { + ExperimentItem.class, + ErrorMessage.class + }), maxItems = 2000))) + }) + public ChunkedOutput<JsonNode> streamExperimentItems( + @RequestBody(content = @Content(schema = @Schema(implementation = ExperimentItemStreamRequest.class))) @NotNull @Valid ExperimentItemStreamRequest request) { + var workspaceId = requestContext.get().getWorkspaceId(); + var userName = requestContext.get().getUserName(); + var workspaceName = requestContext.get().getWorkspaceName(); + log.info("Streaming experiment items by '{}', workspaceId '{}'", request, workspaceId); + var items = experimentItemService.getExperimentItems(request) + .contextWrite(ctx -> ctx.put(RequestContext.USER_NAME, userName) + .put(RequestContext.WORKSPACE_NAME, workspaceName) + .put(RequestContext.WORKSPACE_ID, workspaceId)); + var stream = streamer.getOutputStream(items); + log.info("Streamed experiment items by '{}', workspaceId '{}'", request, workspaceId); + return stream; + } + @POST @Path("/items") @Operation(operationId = "createExperimentItems", summary = "Create experiment items", description = "Create experiment items", responses = { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java index e94abdad59..c742802aca 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java @@ -499,6 +499,9 @@ public Mono get(@NonNull UUID id) { @Override @Trace(dispatcher = true) public Flux<DatasetItem> getItems(@NonNull UUID datasetId, int limit, UUID lastRetrievedId) { + log.info("Getting dataset items by datasetId '{}', limit '{}', lastRetrievedId '{}'", + datasetId, limit, lastRetrievedId); + ST template = new ST(SELECT_DATASET_ITEMS_STREAM); if (lastRetrievedId != null) { diff --git 
a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java index da5b70fee9..e1a2ec04d2 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemService.java @@ -4,6 +4,7 @@ import com.comet.opik.api.DatasetItem; import com.comet.opik.api.DatasetItemBatch; import com.comet.opik.api.DatasetItemSearchCriteria; +import com.comet.opik.api.DatasetItemStreamRequest; import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.api.error.IdentifierMismatchException; import com.comet.opik.infrastructure.auth.RequestContext; @@ -42,8 +43,14 @@ public interface DatasetItemService { Mono getItems(int page, int size, DatasetItemSearchCriteria datasetItemSearchCriteria); - Flux<DatasetItem> getItems(UUID datasetId, int limit, UUID lastRetrievedId); - + /** + * Streams items of the dataset identified by name within a workspace. + * + * @param workspaceId workspace the dataset is looked up in + * @param request dataset name, stream limit and optional id of the last item already retrieved + * @return the dataset items + */ + Flux<DatasetItem> getItems(String workspaceId, DatasetItemStreamRequest request); } @Singleton @@ -110,8 +117,12 @@ public Mono get(@NonNull UUID id) { @Override @Trace(dispatcher = true) - public Flux<DatasetItem> getItems(@NonNull UUID datasetId, int limit, UUID lastRetrievedId) { - return dao.getItems(datasetId, limit, lastRetrievedId); + public Flux<DatasetItem> getItems(@NonNull String workspaceId, @NonNull DatasetItemStreamRequest request) { + log.info("Getting dataset items by '{}' on workspaceId '{}'", request, workspaceId); + // findByName runs inside fromCallable on boundedElastic; presumably it is a blocking lookup (confirm) + return Mono.fromCallable(() -> datasetService.findByName(workspaceId, request.datasetName())) + .subscribeOn(Schedulers.boundedElastic()) + .flatMapMany(dataset -> dao.getItems(dataset.id(), request.steamLimit(), request.lastRetrievedId())); } private Mono saveBatch(DatasetItemBatch batch, UUID id) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentDAO.java index cb5e7558e0..c467b0ff87 100644 --- 
a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentDAO.java @@ -5,6 +5,7 @@ import com.comet.opik.api.FeedbackScoreAverage; import com.comet.opik.utils.JsonUtils; import com.fasterxml.jackson.databind.JsonNode; +import com.google.common.base.Preconditions; import io.r2dbc.spi.Connection; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.Result; @@ -16,6 +17,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.reactivestreams.Publisher; import org.stringtemplate.v4.ST; import reactor.core.publisher.Flux; @@ -338,6 +340,18 @@ SELECT count(id) as count ; """; + private static final String FIND_BY_NAME = """ + SELECT + *, + null AS feedback_scores, + null AS trace_count + FROM experiments + WHERE workspace_id = :workspace_id + AND ilike(name, CONCAT('%', :name, '%')) + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + """; + private static final String FIND_EXPERIMENT_AND_WORKSPACE_BY_DATASET_IDS = """ SELECT id, workspace_id @@ -485,6 +499,25 @@ private void bindSearchCriteria(Statement statement, ExperimentSearchCriteria cr } } + /** + * Finds experiments whose name contains {@code name}, matched case-insensitively through {@code ilike}. + * + * @param name text to search for in experiment names; must not be blank + * @return the matching experiments, with feedback scores and trace count selected as null + */ + Flux<Experiment> findByName(String name) { + Preconditions.checkArgument(StringUtils.isNotBlank(name), "Argument 'name' must not be blank"); + return Mono.from(connectionFactory.create()) + .flatMapMany(connection -> findByName(name, connection)) + .flatMap(this::mapToDto); + } + + private Publisher<? extends Result> findByName(String name, Connection connection) { + log.info("Finding experiment by name '{}'", name); + var statement = connection.createStatement(FIND_BY_NAME).bind("name", name); + return makeFluxContextAware(bindWorkspaceIdToFlux(statement)); + } + public Flux getExperimentWorkspaces(@NonNull Set experimentIds) { if (experimentIds.isEmpty()) { return Flux.empty(); diff --git 
a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemDAO.java index 08594f56af..d437a7d309 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemDAO.java @@ -83,6 +83,19 @@ INSERT INTO experiment_items ( ; """; + private static final String STREAM = """ + SELECT + * + FROM experiment_items + WHERE workspace_id = :workspace_id + AND experiment_id IN :experiment_ids + <if(lastRetrievedId)> AND id \\< :lastRetrievedId <endif> + ORDER BY experiment_id DESC, id DESC, last_updated_at DESC + LIMIT 1 BY id + LIMIT :limit + ; + """; + private static final String DELETE = """ DELETE FROM experiment_items WHERE id IN :ids @@ -201,6 +214,44 @@ private Publisher get(UUID id, Connection connection) { return makeFluxContextAware(bindWorkspaceIdToFlux(statement)); } + /** + * Streams the items of the given experiments, ordered by experiment id and then item id, both descending. + * NOTE(review): the cursor filters on item id only while rows are ordered by experiment id first; confirm that + * paging across several experiments cannot skip items. + * + * @param experimentIds experiments whose items are selected; an empty set yields an empty flux + * @param limit maximum number of items to return + * @param lastRetrievedId optional cursor; when non-null only items with a smaller id are returned + * @return the matching experiment items + */ + public Flux<ExperimentItem> getItems(@NonNull Set<UUID> experimentIds, int limit, UUID lastRetrievedId) { + if (experimentIds.isEmpty()) { + log.info("Getting experiment items by empty experimentIds, limit '{}', lastRetrievedId '{}'", + limit, lastRetrievedId); + return Flux.empty(); + } + return Mono.from(connectionFactory.create()) + .flatMapMany(connection -> getItems(experimentIds, limit, lastRetrievedId, connection)) + .flatMap(this::mapToExperimentItem); + } + + private Publisher<? extends Result> getItems( + Set<UUID> experimentIds, int limit, UUID lastRetrievedId, Connection connection) { + log.info("Getting experiment items by experimentIds count '{}', limit '{}', lastRetrievedId '{}'", + experimentIds.size(), limit, lastRetrievedId); + var template = new ST(STREAM); + if (lastRetrievedId != null) { + template.add("lastRetrievedId", lastRetrievedId); + } + var statement = connection.createStatement(template.render()) + .bind("experiment_ids", experimentIds.toArray(UUID[]::new)) + .bind("limit", limit); + if (lastRetrievedId != null) { + statement.bind("lastRetrievedId", 
lastRetrievedId); + } + return makeFluxContextAware(bindWorkspaceIdToFlux(statement)); + } + public Mono delete(Set ids) { Preconditions.checkArgument(CollectionUtils.isNotEmpty(ids), "Argument 'ids' must not be empty"); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemService.java index 39672965dc..6bffae7f83 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemService.java @@ -1,6 +1,8 @@ package com.comet.opik.domain; +import com.comet.opik.api.Experiment; import com.comet.opik.api.ExperimentItem; +import com.comet.opik.api.ExperimentItemStreamRequest; import com.comet.opik.infrastructure.auth.RequestContext; import com.google.common.base.Preconditions; import jakarta.inject.Inject; @@ -12,7 +14,9 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import java.util.Set; import java.util.UUID; @@ -116,6 +120,15 @@ private NotFoundException newNotFoundException(UUID id) { return new NotFoundException(message); } + public Flux<ExperimentItem> getExperimentItems(@NonNull ExperimentItemStreamRequest request) { + log.info("Getting experiment items by '{}'", request); + return experimentService.findByName(request.experimentName()) + .subscribeOn(Schedulers.boundedElastic()) + .collect(Collectors.mapping(Experiment::id, Collectors.toUnmodifiableSet())) + .flatMapMany(experimentIds -> experimentItemDAO.getItems( + experimentIds, request.limit(), request.lastRetrievedId())); + } + public Mono delete(@NonNull Set ids) { Preconditions.checkArgument(CollectionUtils.isNotEmpty(ids), "Argument 'ids' must not be empty"); @@ -123,5 +136,4 @@ public Mono delete(@NonNull Set 
ids) { log.info("Deleting experiment items, count '{}'", ids.size()); return experimentItemDAO.delete(ids).then(); } - } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentService.java index 6576742993..3e06bac95b 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentService.java @@ -6,6 +6,7 @@ import com.comet.opik.api.ExperimentSearchCriteria; import com.comet.opik.api.error.EntityAlreadyExistsException; import com.comet.opik.infrastructure.auth.RequestContext; +import com.google.common.base.Preconditions; import jakarta.inject.Inject; import jakarta.inject.Singleton; import jakarta.ws.rs.ClientErrorException; @@ -15,6 +16,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -60,6 +62,12 @@ public Mono find( })); } + public Flux<Experiment> findByName(String name) { + Preconditions.checkArgument(StringUtils.isNotBlank(name), "Argument 'name' must not be blank"); + log.info("Finding experiments by name '{}'", name); + return experimentDAO.findByName(name); + } + public Mono getById(@NonNull UUID id) { log.info("Getting experiment by id '{}'", id); return experimentDAO.getById(id) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/Streamer.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/Streamer.java new file mode 100644 index 0000000000..5531b1fb88 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/Streamer.java @@ -0,0 +1,74 @@ +package com.comet.opik.domain; + +import com.comet.opik.utils.JsonUtils; +import com.fasterxml.jackson.databind.JsonNode; +import io.dropwizard.jersey.errors.ErrorMessage; +import jakarta.inject.Singleton; 
+import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.glassfish.jersey.server.ChunkedOutput; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.concurrent.TimeoutException; + +/** + * Bridges a reactive {@link Flux} to a Jersey {@link ChunkedOutput}: each emitted item is converted to a + * {@link JsonNode} and written as its own chunk, with {@code "\r\n"} as the chunk delimiter. + */ +@Singleton +@Slf4j +public class Streamer { + + /** + * Schedules the subscription to {@code flux} on the bounded-elastic scheduler and returns without waiting for it. + * The returned output is closed when the flux terminates, whatever the reason. + * + * @param flux the items to stream + * @return the chunked output the items are written to + */ + public <T> ChunkedOutput<JsonNode> getOutputStream(@NonNull Flux<T> flux) { + var outputStream = new ChunkedOutput<JsonNode>(JsonNode.class, "\r\n"); + Schedulers.boundedElastic() + .schedule(() -> flux.doOnNext(item -> sendItem(item, outputStream)) + .onErrorResume(throwable -> handleError(throwable, outputStream)) + .doFinally(signalType -> close(outputStream)) + .subscribe( + item -> { }, + throwable -> log.error("Error while streaming items", throwable))); + return outputStream; + } + + private <T> void sendItem(T item, ChunkedOutput<JsonNode> outputStream) { + try { + outputStream.write(JsonUtils.readTree(item)); + } catch (IOException exception) { + throw new UncheckedIOException(exception); + } + } + + /** + * Writes an error chunk to the client when the failure is a {@link TimeoutException}, then re-raises the + * original error so the subscriber still observes it. + */ + private <T> Flux<T> handleError(Throwable throwable, ChunkedOutput<JsonNode> outputStream) { + if (throwable instanceof TimeoutException) { + try { + outputStream.write(JsonUtils.readTree(new ErrorMessage(500, "Streaming operation timed out"))); + } catch (IOException ioException) { + log.error("Failed to stream error message to client", ioException); + } + } + return Flux.error(throwable); + } + + private void close(ChunkedOutput<JsonNode> outputStream) { + try { + outputStream.close(); + } catch (IOException exception) { + log.error("Error while closing output stream", exception); + } + } +} diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceIntegrationTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceIntegrationTest.java index a90e1f95fe..4ae4e52348 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceIntegrationTest.java +++ 
b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceIntegrationTest.java @@ -1,10 +1,10 @@ package com.comet.opik.api.resources.v1.priv; -import com.comet.opik.api.Dataset; import com.comet.opik.api.DatasetItem; import com.comet.opik.api.DatasetItemStreamRequest; import com.comet.opik.domain.DatasetItemService; import com.comet.opik.domain.DatasetService; +import com.comet.opik.domain.Streamer; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.json.JsonNodeMessageBodyWriter; import com.comet.opik.podam.PodamFactoryUtils; @@ -31,8 +31,6 @@ import static com.comet.opik.domain.ProjectService.DEFAULT_USER; import static com.comet.opik.domain.ProjectService.DEFAULT_WORKSPACE_NAME; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.eq; import static org.mockito.Mockito.when; @ExtendWith(DropwizardExtensionsSupport.class) @@ -44,7 +42,8 @@ class DatasetsResourceIntegrationTest { private static final TimeBasedEpochGenerator timeBasedGenerator = Generators.timeBasedEpochGenerator(); private static final ResourceExtension EXT = ResourceExtension.builder() - .addResource(new DatasetsResource(service, itemService, () -> requestContext, timeBasedGenerator::generate)) + .addResource(new DatasetsResource( + service, itemService, () -> requestContext, timeBasedGenerator::generate, new Streamer())) .addProvider(JsonNodeMessageBodyWriter.class) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .build(); @@ -54,12 +53,7 @@ class DatasetsResourceIntegrationTest { @Test void testStreamErrorHandling() { var datasetName = "test"; - String workspaceId = UUID.randomUUID().toString(); - - Dataset dataset = Dataset.builder().id(UUID.randomUUID()).name(datasetName).build(); - - when(service.findByName(workspaceId, datasetName)) - .thenReturn(dataset); + var workspaceId = UUID.randomUUID().toString(); 
when(requestContext.getUserName()) .thenReturn(DEFAULT_USER); @@ -78,13 +72,15 @@ void testStreamErrorHandling() { sink.error(new TimeoutException("Connection timed out")); }); - when(itemService.getItems(eq(dataset.id()), eq(500), any())) + var request = DatasetItemStreamRequest.builder().datasetName(datasetName).steamLimit(500).build(); + + when(itemService.getItems(workspaceId, request)) .thenReturn(Flux.defer(() -> itemFlux)); try (var response = EXT.target("/v1/private/datasets/items/stream") .request() .header("workspace", DEFAULT_WORKSPACE_NAME) - .post(Entity.json(DatasetItemStreamRequest.builder().datasetName(datasetName).build()))) { + .post(Entity.json(request))) { try (var inputStream = response.readEntity(new GenericType>() { })) { diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java index 3b4e852e8b..f2cf8a0fbf 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java @@ -4,6 +4,7 @@ import com.comet.opik.api.DatasetItemBatch; import com.comet.opik.api.Experiment; import com.comet.opik.api.ExperimentItem; +import com.comet.opik.api.ExperimentItemStreamRequest; import com.comet.opik.api.ExperimentItemsBatch; import com.comet.opik.api.ExperimentItemsDelete; import com.comet.opik.api.FeedbackScore; @@ -22,16 +23,22 @@ import com.comet.opik.api.resources.utils.WireMockUtils; import com.comet.opik.domain.FeedbackScoreMapper; import com.comet.opik.podam.PodamFactoryUtils; +import com.comet.opik.utils.JsonUtils; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.uuid.Generators; import com.fasterxml.uuid.impl.TimeBasedEpochGenerator; import com.github.tomakehurst.wiremock.client.WireMock; import 
com.redis.testcontainers.RedisContainer; import io.dropwizard.jersey.errors.ErrorMessage; import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.core.GenericType; import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.StringUtils; import org.assertj.core.api.recursive.comparison.RecursiveComparisonConfiguration; +import org.glassfish.jersey.client.ChunkedInput; import org.jdbi.v3.core.Jdbi; import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.AfterAll; @@ -57,7 +64,9 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.sql.SQLException; +import java.util.ArrayList; import java.util.Collection; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Set; @@ -95,13 +104,19 @@ class ExperimentsResourceTest { private static final String[] EXPERIMENT_IGNORED_FIELDS = new String[]{ "id", "datasetId", "name", "feedbackScores", "traceCount", "createdAt", "lastUpdatedAt", "createdBy", "lastUpdatedBy"}; - public static final String[] IGNORED_FIELDS = {"input", "output", "feedbackScores", "createdAt", "lastUpdatedAt", - "createdBy", "lastUpdatedBy"}; + public static final String[] ITEM_IGNORED_FIELDS = {"input", "output", "feedbackScores", "createdAt", + "lastUpdatedAt", "createdBy", "lastUpdatedBy"}; private static final String WORKSPACE_ID = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); private static final String TEST_WORKSPACE = UUID.randomUUID().toString(); + private static final GenericType> CHUNKED_INPUT_STRING_GENERIC_TYPE = new GenericType<>() { + }; + + private static final TypeReference EXPERIMENT_ITEM_TYPE_REFERENCE = new TypeReference<>() { + }; + private static final TimeBasedEpochGenerator GENERATOR = Generators.timeBasedEpochGenerator(); private static final RedisContainer REDIS = 
RedisContainerUtils.newRedisContainer(); @@ -153,10 +168,10 @@ private static void mockTargetWorkspace(String apiKey, String workspaceName, Str AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER); } - private static void mockSessionCookieTargetWorkspace(String sessionToken, String workspaceName, - String workspaceId) { - AuthTestUtils.mockSessionCookieTargetWorkspace(wireMock.server(), sessionToken, workspaceName, workspaceId, - USER); + private static void mockSessionCookieTargetWorkspace( + String sessionToken, String workspaceName, String workspaceId) { + AuthTestUtils.mockSessionCookieTargetWorkspace( + wireMock.server(), sessionToken, workspaceName, workspaceId, USER); } @AfterAll @@ -181,7 +196,6 @@ Stream credentials() { @BeforeEach void setUp() { - wireMock.server().stubFor( post(urlPathEqualTo("/opik/auth")) .withHeader(HttpHeaders.AUTHORIZATION, equalTo(fakeApikey)) @@ -294,8 +308,7 @@ void deleteExperimentItems__whenApiKeyIsPresent__thenReturnProperResponse(String createAndAssert(createRequest, okApikey, workspaceName); - createRequest.experimentItems() - .forEach(item -> ExperimentsResourceTest.this.getAndAssert(item, workspaceName, okApikey)); + createRequest.experimentItems().forEach(item -> getAndAssert(item, workspaceName, okApikey)); var ids = createRequest.experimentItems().stream().map(ExperimentItem::id).collect(Collectors.toSet()); var deleteRequest = ExperimentItemsDelete.builder().ids(ids).build(); @@ -346,7 +359,6 @@ void createExperimentItems__whenApiKeyIsPresent__thenReturnProperResponse(String @ParameterizedTest @MethodSource("credentials") void getExperimentItemById__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, boolean success) { - var workspaceName = UUID.randomUUID().toString(); var expectedExperimentItem = podamFactory.manufacturePojo(ExperimentItem.class); var id = expectedExperimentItem.id(); @@ -408,8 +420,8 @@ void setUp() { @ParameterizedTest @MethodSource("credentials") 
- void getById__whenSessionTokenIsPresent__thenReturnProperResponse(String currentSessionToken, boolean success, - String workspaceName) { + void getById__whenSessionTokenIsPresent__thenReturnProperResponse( + String currentSessionToken, boolean success, String workspaceName) { var expectedExperiment = podamFactory.manufacturePojo(Experiment.class); mockTargetWorkspace(API_KEY, workspaceName, WORKSPACE_ID); @@ -433,8 +445,8 @@ void getById__whenSessionTokenIsPresent__thenReturnProperResponse(String current @ParameterizedTest @MethodSource("credentials") - void create__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, boolean success, - String workspaceName) { + void create__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, boolean success, String workspaceName) { var expectedExperiment = podamFactory.manufacturePojo(Experiment.class); mockTargetWorkspace(API_KEY, sessionToken, WORKSPACE_ID); @@ -456,8 +468,8 @@ void create__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionT @ParameterizedTest @MethodSource("credentials") - void find__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, boolean success, - String workspaceName) { + void find__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, boolean success, String workspaceName) { var workspaceId = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -495,15 +507,14 @@ void find__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionTok @ParameterizedTest @MethodSource("credentials") - void deleteExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, - boolean success, String workspaceName) { + void deleteExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, boolean success, String workspaceName) { mockTargetWorkspace(API_KEY, workspaceName, WORKSPACE_ID); var createRequest = 
podamFactory.manufacturePojo(ExperimentItemsBatch.class); createAndAssert(createRequest, API_KEY, workspaceName); - createRequest.experimentItems() - .forEach(item -> ExperimentsResourceTest.this.getAndAssert(item, workspaceName, API_KEY)); + createRequest.experimentItems().forEach(item -> getAndAssert(item, workspaceName, API_KEY)); var ids = createRequest.experimentItems().stream().map(ExperimentItem::id).collect(Collectors.toSet()); var deleteRequest = ExperimentItemsDelete.builder().ids(ids).build(); @@ -527,8 +538,8 @@ void deleteExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse( @ParameterizedTest @MethodSource("credentials") - void createExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, - boolean success, String workspaceName) { + void createExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, boolean success, String workspaceName) { var request = podamFactory.manufacturePojo(ExperimentItemsBatch.class); @@ -550,8 +561,8 @@ void createExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse( @ParameterizedTest @MethodSource("credentials") - void getExperimentItemById__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, - boolean success, String workspaceName) { + void getExperimentItemById__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, boolean success, String workspaceName) { mockTargetWorkspace(API_KEY, workspaceName, WORKSPACE_ID); @@ -602,12 +613,11 @@ void findByDatasetId() { .datasetName(datasetName) .build()) .toList(); - experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment, - apiKey, workspaceName)); + experiments.forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var unexpectedExperiments = List.of(podamFactory.manufacturePojo(Experiment.class)); - unexpectedExperiments.forEach(expectedExperiment -> 
ExperimentsResourceTest.this - .createAndAssert(expectedExperiment, apiKey, workspaceName)); + unexpectedExperiments + .forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var pageSize = experiments.size() - 2; var datasetId = getAndAssert(experiments.getFirst().id(), experiments.getFirst(), workspaceName, apiKey) @@ -651,12 +661,12 @@ void findByName(String name, String nameQueryParam) { .name(name) .build()) .toList(); - experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment, + experiments.forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var unexpectedExperiments = List.of(podamFactory.manufacturePojo(Experiment.class)); - unexpectedExperiments.forEach(expectedExperiment -> ExperimentsResourceTest.this - .createAndAssert(expectedExperiment, apiKey, workspaceName)); + unexpectedExperiments + .forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var pageSize = experiments.size() - 2; UUID datasetId = null; @@ -672,7 +682,6 @@ void findByName(String name, String nameQueryParam) { @Test void findByDatasetIdAndName() { - var workspaceName = UUID.randomUUID().toString(); var workspaceId = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -690,12 +699,12 @@ void findByDatasetIdAndName() { .metadata(null) .build()) .toList(); - experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment, + experiments.forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var unexpectedExperiments = List.of(podamFactory.manufacturePojo(Experiment.class)); - unexpectedExperiments.forEach(expectedExperiment -> ExperimentsResourceTest.this - .createAndAssert(expectedExperiment, apiKey, workspaceName)); + unexpectedExperiments + .forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); 
var pageSize = experiments.size() - 2; var datasetId = getAndAssert(experiments.getFirst().id(), experiments.getFirst(), workspaceName, apiKey) @@ -720,7 +729,7 @@ void findAll() { var experiments = PodamFactoryUtils.manufacturePojoList(podamFactory, Experiment.class); - experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment, + experiments.forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var page = 1; @@ -760,8 +769,7 @@ void findAllAndCalculateFeedbackAvg() { .build()) .toList(); - experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment, - apiKey, workspaceName)); + experiments.forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var noScoreExperiment = podamFactory.manufacturePojo(Experiment.class); createAndAssert(noScoreExperiment, apiKey, workspaceName); @@ -1007,8 +1015,8 @@ private void deleteTrace(UUID id, String apiKey, String workspaceName) { } } - private @NotNull Map> getExpectedScoresPerExperiment(List experiments, - List experimentItems) { + private @NotNull Map> getExpectedScoresPerExperiment( + List experiments, List experimentItems) { return experiments.stream() .map(experiment -> Map.entry(experiment.id(), experimentItems .stream() @@ -1328,7 +1336,6 @@ void getNotFound() { @Test void createAndGetWithDeletedTrace() { - var workspaceName = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); var workspaceId = UUID.randomUUID().toString(); @@ -1426,7 +1433,6 @@ void createAndGetWithDeletedTrace() { } private ExperimentItemsBatch addRandomExperiments(List experimentItems) { - // When storing the experiment items in batch, adding some more unrelated random ones var experimentItemsBatch = podamFactory.manufacturePojo(ExperimentItemsBatch.class); experimentItemsBatch = experimentItemsBatch.toBuilder() @@ -1440,13 +1446,11 @@ private ExperimentItemsBatch 
addRandomExperiments(List experimen private Map getScoresMap(Experiment experiment) { List feedbackScores = experiment.feedbackScores(); - if (feedbackScores != null) { return feedbackScores .stream() .collect(Collectors.toMap(FeedbackScoreAverage::name, FeedbackScoreAverage::value)); } - return null; } @@ -1622,6 +1626,108 @@ void getById() { } } + @Nested + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class StreamExperimentItems { + + @Test + void streamByExperimentName() { + var apiKey = UUID.randomUUID().toString(); + var workspaceName = RandomStringUtils.randomAlphanumeric(10); + var workspaceId = UUID.randomUUID().toString(); + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var experiment1 = podamFactory.manufacturePojo(Experiment.class); + createAndAssert(experiment1, apiKey, workspaceName); + + var experiment2 = podamFactory.manufacturePojo(Experiment.class).toBuilder() + .name(experiment1.name().substring(1, experiment1.name().length() - 1).toLowerCase()) + .build(); + createAndAssert(experiment2, apiKey, workspaceName); + + var experiment3 = podamFactory.manufacturePojo(Experiment.class); + createAndAssert(experiment3, apiKey, workspaceName); + + var experimentItems1 = PodamFactoryUtils.manufacturePojoList(podamFactory, ExperimentItem.class).stream() + .map(experimentItem -> experimentItem.toBuilder().experimentId(experiment1.id()).build()) + .collect(Collectors.toUnmodifiableSet()); + var createRequest1 = ExperimentItemsBatch.builder().experimentItems(experimentItems1).build(); + createAndAssert(createRequest1, apiKey, workspaceName); + + var experimentItems2 = PodamFactoryUtils.manufacturePojoList(podamFactory, ExperimentItem.class).stream() + .map(experimentItem -> experimentItem.toBuilder().experimentId(experiment2.id()).build()) + .collect(Collectors.toUnmodifiableSet()); + var createRequest2 = ExperimentItemsBatch.builder().experimentItems(experimentItems2).build(); + createAndAssert(createRequest2, apiKey, workspaceName); + + var 
experimentItems3 = PodamFactoryUtils.manufacturePojoList(podamFactory, ExperimentItem.class).stream() + .map(experimentItem -> experimentItem.toBuilder().experimentId(experiment3.id()).build()) + .collect(Collectors.toUnmodifiableSet()); + var createRequest3 = ExperimentItemsBatch.builder().experimentItems(experimentItems3).build(); + createAndAssert(createRequest3, apiKey, workspaceName); + + var size = experimentItems1.size() + experimentItems2.size(); + var limit = size / 2; + + var expectedExperimentItems = Stream.concat(experimentItems1.stream(), experimentItems2.stream()) + .sorted(Comparator.comparing(ExperimentItem::experimentId).thenComparing(ExperimentItem::id)) + .toList() + .reversed(); + + var expectedExperimentItems1 = expectedExperimentItems.subList(0, limit); + var expectedExperimentItems2 = expectedExperimentItems.subList(limit, size); + + var streamRequest1 = ExperimentItemStreamRequest.builder() + .experimentName(experiment2.name()) + .limit(limit) + .build(); + var unexpectedExperimentItems1 = Stream.concat(expectedExperimentItems2.stream(), experimentItems3.stream()) + .toList(); + streamAndAssert( + streamRequest1, expectedExperimentItems1, unexpectedExperimentItems1, apiKey, workspaceName); + + var streamRequest2 = ExperimentItemStreamRequest.builder() + .experimentName(experiment2.name()) + .lastRetrievedId(expectedExperimentItems1.getLast().id()) + .build(); + var unexpectedExperimentItems2 = Stream.concat(expectedExperimentItems1.stream(), experimentItems3.stream()) + .toList(); + streamAndAssert( + streamRequest2, expectedExperimentItems2, unexpectedExperimentItems2, apiKey, workspaceName); + } + + @Test + void streamByExperimentNameWithNoItems() { + var apiKey = UUID.randomUUID().toString(); + var workspaceName = RandomStringUtils.randomAlphanumeric(10); + var workspaceId = UUID.randomUUID().toString(); + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var experiment = podamFactory.manufacturePojo(Experiment.class); + 
createAndAssert(experiment, apiKey, workspaceName); + + var streamRequest = ExperimentItemStreamRequest.builder().experimentName(experiment.name()).build(); + var expectedExperimentItems = List.of(); + var unexpectedExperimentItems1 = List.of(); + streamAndAssert(streamRequest, expectedExperimentItems, unexpectedExperimentItems1, apiKey, workspaceName); + } + + @Test + void streamByExperimentNameWithoutExperiments() { + var apiKey = UUID.randomUUID().toString(); + var workspaceName = RandomStringUtils.randomAlphanumeric(10); + var workspaceId = UUID.randomUUID().toString(); + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var streamRequest = ExperimentItemStreamRequest.builder() + .experimentName(RandomStringUtils.randomAlphanumeric(10)) + .build(); + var expectedExperimentItems = List.of(); + var unexpectedExperimentItems1 = List.of(); + streamAndAssert(streamRequest, expectedExperimentItems, unexpectedExperimentItems1, apiKey, workspaceName); + } + } + @Nested @TestInstance(TestInstance.Lifecycle.PER_CLASS) class CreateExperimentsItems { @@ -1632,13 +1738,11 @@ void createAndGet() { createAndAssert(request, API_KEY, TEST_WORKSPACE); - request.experimentItems() - .forEach(item -> ExperimentsResourceTest.this.getAndAssert(item, TEST_WORKSPACE, API_KEY)); + request.experimentItems().forEach(item -> getAndAssert(item, TEST_WORKSPACE, API_KEY)); } @Test void insertInvalidDatasetItemWorkspace() { - var workspaceName = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -1671,7 +1775,6 @@ void insertInvalidDatasetItemWorkspace() { @Test void insertInvalidExperimentWorkspace() { - var workspaceName = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); var workspaceId = UUID.randomUUID().toString(); @@ -1702,7 +1805,7 @@ void insertInvalidExperimentWorkspace() { } } - UUID createDatasetItem(String workspaceName, String apiKey) { + private UUID createDatasetItem(String workspaceName, String apiKey) { var item = 
podamFactory.manufacturePojo(DatasetItem.class); var batch = podamFactory.manufacturePojo(DatasetItemBatch.class).toBuilder() @@ -1750,7 +1853,6 @@ Stream insertInvalidId() { @ParameterizedTest @MethodSource void insertInvalidId(ExperimentItem experimentItem, String expectedErrorMessage) { - var request = ExperimentItemsBatch.builder() .experimentItems(Set.of(experimentItem)).build(); var expectedError = new com.comet.opik.api.error.ErrorMessage( @@ -1781,8 +1883,7 @@ void delete() { var createRequest = podamFactory.manufacturePojo(ExperimentItemsBatch.class).toBuilder() .build(); createAndAssert(createRequest, API_KEY, TEST_WORKSPACE); - createRequest.experimentItems() - .forEach(item -> ExperimentsResourceTest.this.getAndAssert(item, TEST_WORKSPACE, API_KEY)); + createRequest.experimentItems().forEach(item -> getAndAssert(item, TEST_WORKSPACE, API_KEY)); var ids = createRequest.experimentItems().stream().map(ExperimentItem::id).collect(Collectors.toSet()); var deleteRequest = ExperimentItemsDelete.builder().ids(ids).build(); @@ -1797,7 +1898,7 @@ void delete() { assertThat(actualResponse.hasEntity()).isFalse(); } - ids.forEach(id -> ExperimentsResourceTest.this.getAndAssertNotFound(id, API_KEY, TEST_WORKSPACE)); + ids.forEach(id -> getAndAssertNotFound(id, API_KEY, TEST_WORKSPACE)); } } @@ -1828,19 +1929,31 @@ private void getAndAssert(ExperimentItem expectedExperimentItem, String workspac assertThat(actualExperimentItem) .usingRecursiveComparison() - .ignoringFields(IGNORED_FIELDS) + .ignoringFields(ITEM_IGNORED_FIELDS) .isEqualTo(expectedExperimentItem); - assertThat(actualExperimentItem.input()).isNull(); - assertThat(actualExperimentItem.output()).isNull(); - assertThat(actualExperimentItem.feedbackScores()).isNull(); - assertThat(actualExperimentItem.createdAt()).isAfter(expectedExperimentItem.createdAt()); - assertThat(actualExperimentItem.lastUpdatedAt()).isAfter(expectedExperimentItem.lastUpdatedAt()); - 
assertThat(actualExperimentItem.createdBy()).isEqualTo(USER); - assertThat(actualExperimentItem.lastUpdatedBy()).isEqualTo(USER); + assertIgnoredFields(actualExperimentItem, expectedExperimentItem); } } + private void assertIgnoredFields( + List actualExperimentItems, List expectedExperimentItems) { + assertThat(actualExperimentItems).hasSameSizeAs(expectedExperimentItems); + for (int i = 0; i < actualExperimentItems.size(); i++) { + assertIgnoredFields(actualExperimentItems.get(i), expectedExperimentItems.get(i)); + } + } + + private void assertIgnoredFields(ExperimentItem actualExperimentItem, ExperimentItem expectedExperimentItem) { + assertThat(actualExperimentItem.input()).isNull(); + assertThat(actualExperimentItem.output()).isNull(); + assertThat(actualExperimentItem.feedbackScores()).isNull(); + assertThat(actualExperimentItem.createdAt()).isAfter(expectedExperimentItem.createdAt()); + assertThat(actualExperimentItem.lastUpdatedAt()).isAfter(expectedExperimentItem.lastUpdatedAt()); + assertThat(actualExperimentItem.createdBy()).isEqualTo(USER); + assertThat(actualExperimentItem.lastUpdatedBy()).isEqualTo(USER); + } + private void getAndAssertNotFound(UUID id, String apiKey, String workspaceName) { var expectedError = new ErrorMessage(404, "Not found experiment item with id '%s'".formatted(id)); try (var actualResponse = client.target(getExperimentItemsPath()) @@ -1858,6 +1971,49 @@ private void getAndAssertNotFound(UUID id, String apiKey, String workspaceName) } } + private void streamAndAssert( + ExperimentItemStreamRequest request, + List expectedExperimentItems, + List unexpectedExperimentItems, + String apiKey, + String workspaceName) { + try (var actualResponse = client.target(getExperimentItemsPath()) + .path("stream") + .request() + .accept(MediaType.APPLICATION_OCTET_STREAM) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(request))) { + + 
assertThat(actualResponse.getStatus()).isEqualTo(200); + + var actualExperimentItems = getStreamedItems(actualResponse); + + assertThat(actualExperimentItems) + .usingRecursiveFieldByFieldElementComparatorIgnoringFields(ITEM_IGNORED_FIELDS) + .containsExactlyElementsOf(expectedExperimentItems); + + assertIgnoredFields(actualExperimentItems, expectedExperimentItems); + + if (!unexpectedExperimentItems.isEmpty()) { + assertThat(actualExperimentItems) + .usingRecursiveFieldByFieldElementComparatorIgnoringFields(ITEM_IGNORED_FIELDS) + .doesNotContainAnyElementsOf(unexpectedExperimentItems); + } + } + } + + private List getStreamedItems(Response response) { + var items = new ArrayList(); + try (var inputStream = response.readEntity(CHUNKED_INPUT_STRING_GENERIC_TYPE)) { + String stringItem; + while ((stringItem = inputStream.read()) != null) { + items.add(JsonUtils.readValue(stringItem, EXPERIMENT_ITEM_TYPE_REFERENCE)); + } + } + return items; + } + private String getExperimentsPath() { return URL_TEMPLATE.formatted(baseURI); } diff --git a/apps/opik-frontend/package-lock.json b/apps/opik-frontend/package-lock.json index 7f75512e7c..f0edd9f186 100644 --- a/apps/opik-frontend/package-lock.json +++ b/apps/opik-frontend/package-lock.json @@ -31,6 +31,7 @@ "@tanstack/react-query": "^5.45.0", "@tanstack/react-router": "^1.36.3", "@tanstack/react-table": "^8.17.3", + "@types/diff": "^5.2.2", "@types/md5": "^2.3.5", "@types/segment-analytics": "^0.0.38", "@uiw/react-codemirror": "^4.23.0", @@ -41,6 +42,7 @@ "codemirror": "^6.0.1", "date-fns": "^3.6.0", "dayjs": "^1.11.11", + "diff": "^7.0.0", "flattie": "^1.1.1", "js-yaml": "^4.1.0", "lodash": "^4.17.21", @@ -5698,6 +5700,11 @@ "@babel/types": "^7.20.7" } }, + "node_modules/@types/diff": { + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/@types/diff/-/diff-5.2.2.tgz", + "integrity": "sha512-qVqLpd49rmJA2nZzLVsmfS/aiiBpfVE95dHhPVwG0NmSBAt+riPxnj53wq2oBq5m4Q2RF1IWFEUpnZTgrQZfEQ==" + }, "node_modules/@types/estree": 
{ "version": "1.0.5", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", @@ -7311,6 +7318,14 @@ "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", "integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==" }, + "node_modules/diff": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/diff/-/diff-7.0.0.tgz", + "integrity": "sha512-PJWHUb1RFevKCwaFA9RlG5tCd+FO5iRh9A8HEtkmBH2Li03iJriB6m6JIN4rGz3K3JLawI7/veA1xzRKP6ISBw==", + "engines": { + "node": ">=0.3.1" + } + }, "node_modules/diff-sequences": { "version": "29.6.3", "resolved": "https://registry.npmjs.org/diff-sequences/-/diff-sequences-29.6.3.tgz", diff --git a/apps/opik-frontend/package.json b/apps/opik-frontend/package.json index 6f17f8a3e1..e098c10316 100644 --- a/apps/opik-frontend/package.json +++ b/apps/opik-frontend/package.json @@ -48,6 +48,7 @@ "@tanstack/react-query": "^5.45.0", "@tanstack/react-router": "^1.36.3", "@tanstack/react-table": "^8.17.3", + "@types/diff": "^5.2.2", "@types/md5": "^2.3.5", "@types/segment-analytics": "^0.0.38", "@uiw/react-codemirror": "^4.23.0", @@ -58,6 +59,7 @@ "codemirror": "^6.0.1", "date-fns": "^3.6.0", "dayjs": "^1.11.11", + "diff": "^7.0.0", "flattie": "^1.1.1", "js-yaml": "^4.1.0", "lodash": "^4.17.21", diff --git a/apps/opik-frontend/src/components/pages/CompareExperimentsPage/CompareExperimentsHeader.tsx b/apps/opik-frontend/src/components/pages/CompareExperimentsPage/CompareExperimentsHeader.tsx index ec3d014c30..89e5b9860a 100644 --- a/apps/opik-frontend/src/components/pages/CompareExperimentsPage/CompareExperimentsHeader.tsx +++ b/apps/opik-frontend/src/components/pages/CompareExperimentsPage/CompareExperimentsHeader.tsx @@ -3,15 +3,20 @@ import { FlaskConical, X } from "lucide-react"; import { HeaderContext } from "@tanstack/react-table"; import { Button } from "@/components/ui/button"; -import { ExperimentsCompare } from "@/types/datasets"; +import { 
Experiment, ExperimentsCompare } from "@/types/datasets"; import { JsonParam, useQueryParam } from "use-query-params"; -import useExperimentById from "@/api/datasets/useExperimentById"; + +type CustomMeta = { + experiment?: Experiment; +}; const CompareExperimentsHeader: React.FunctionComponent< HeaderContext -> = ({ table, header }) => { - const experimentId = header?.id; - const hasData = table.getRowCount() > 0; +> = (context) => { + const { custom } = context.column.columnDef.meta ?? {}; + const { experiment } = (custom ?? {}) as CustomMeta; + const experimentId = context.header?.id; + const hasData = context.table.getRowCount() > 0; const [experimentIds, setExperimentsIds] = useQueryParam( "experiments", JsonParam, @@ -20,16 +25,7 @@ const CompareExperimentsHeader: React.FunctionComponent< }, ); - const { data } = useExperimentById( - { - experimentId, - }, - { - refetchOnMount: false, - }, - ); - - const name = data?.name || experimentId; + const name = experiment?.name || experimentId; return (
{ - + > = (context) => { + const { custom } = context.column.columnDef.meta ?? {}; + const { onlyDiff } = (custom ?? {}) as CustomMeta; const experimentId = context.column?.id; const compareConfig = context.row.original; const data = compareConfig.data[experimentId]; + const baseData = compareConfig.data[compareConfig.base]; + + const renderContent = () => { + if (isUndefined(data)) { + return No value; + } - if (data === undefined) { - return null; - } + return ( +
+ {showDiffView ? ( + + ) : ( + toString(data) + )} +
+ ); + }; + + const showDiffView = + onlyDiff && + Object.values(compareConfig.data).length >= 2 && + experimentId !== compareConfig.base; return ( -
- {String(data)} -
+ {renderContent()}
); }; diff --git a/apps/opik-frontend/src/components/pages/CompareExperimentsPage/ConfigurationTab/ConfigurationTab.tsx b/apps/opik-frontend/src/components/pages/CompareExperimentsPage/ConfigurationTab/ConfigurationTab.tsx index dca0c89c96..9d41b7b8c5 100644 --- a/apps/opik-frontend/src/components/pages/CompareExperimentsPage/ConfigurationTab/ConfigurationTab.tsx +++ b/apps/opik-frontend/src/components/pages/CompareExperimentsPage/ConfigurationTab/ConfigurationTab.tsx @@ -4,6 +4,7 @@ import useLocalStorageState from "use-local-storage-state"; import isObject from "lodash/isObject"; import uniq from "lodash/uniq"; import toLower from "lodash/toLower"; +import find from "lodash/find"; import { flattie } from "flattie"; import { COLUMN_TYPE, ColumnData } from "@/types/shared"; @@ -79,6 +80,12 @@ const ConfigurationTab: React.FunctionComponent = ({ accessorKey: id, header: CompareExperimentsHeader as never, cell: CompareConfigCell as never, + meta: { + custom: { + onlyDiff, + experiment: find(experiments, (e) => e.id === id), + }, + }, size, minSize: 120, }); @@ -93,7 +100,7 @@ const ConfigurationTab: React.FunctionComponent = ({ }); return retVal; - }, [columnsWidth, experimentsIds]); + }, [columnsWidth, experimentsIds, onlyDiff, experiments]); const flattenExperimentMetadataMap = useMemo(() => { return experiments.reduce>>( @@ -164,7 +171,7 @@ const ConfigurationTab: React.FunctionComponent = ({ } return ( -
+
= ({ experimentsIds = [], + experiments, }) => { const datasetId = useDatasetIdFromCompareExperimentsURL(); const workspaceName = useAppStore((state) => state.activeWorkspaceName); @@ -177,6 +179,7 @@ const ExperimentItemsTab: React.FunctionComponent = ({ meta: { custom: { openTrace: setTraceId, + experiment: find(experiments, (e) => e.id === id), }, }, size, @@ -193,7 +196,14 @@ const ExperimentItemsTab: React.FunctionComponent = ({ }); return retVal; - }, [columnsWidth, selectedColumns, columnsOrder, experimentsIds, setTraceId]); + }, [ + columnsWidth, + selectedColumns, + columnsOrder, + experimentsIds, + setTraceId, + experiments, + ]); const { data, isPending } = useCompareExperimentsList( { diff --git a/apps/opik-frontend/src/components/shared/CodeDiff/TextDiff.tsx b/apps/opik-frontend/src/components/shared/CodeDiff/TextDiff.tsx new file mode 100644 index 0000000000..1c9da092ac --- /dev/null +++ b/apps/opik-frontend/src/components/shared/CodeDiff/TextDiff.tsx @@ -0,0 +1,35 @@ +import React, { useMemo } from "react"; +import { diffLines } from "diff"; +import { cn } from "@/lib/utils"; + +type CodeDiffProps = { + content1: string; + content2: string; +}; + +const TextDiff: React.FunctionComponent = ({ + content1, + content2, +}) => { + return useMemo(() => { + const changes = diffLines(content1, content2); + + return ( +
+ {changes.map((c, index) => ( +
+ {c.value} +
+ ))} +
+ ); + }, [content1, content2]); +}; + +export default TextDiff; diff --git a/apps/opik-frontend/src/components/ui/table.tsx b/apps/opik-frontend/src/components/ui/table.tsx index 87924465ab..ec69ae2bec 100644 --- a/apps/opik-frontend/src/components/ui/table.tsx +++ b/apps/opik-frontend/src/components/ui/table.tsx @@ -89,7 +89,7 @@ const TableCell = React.forwardRef< context.measureText(v).width); }; + +export const toString = (value?: string | number | boolean | null) => + isUndefined(value) ? "" : String(value);