From 28b5ad0ff6c6ea4a8b5c543d3c07c4044093899d Mon Sep 17 00:00:00 2001 From: Andres Cruz Date: Wed, 18 Sep 2024 13:22:41 +0200 Subject: [PATCH] OPIK-76 Add stream experiment items endpoint --- .../opik/api/ExperimentItemStreamRequest.java | 25 ++ .../v1/priv/ExperimentsResource.java | 22 ++ .../com/comet/opik/domain/ExperimentDAO.java | 28 ++ .../comet/opik/domain/ExperimentItemDAO.java | 36 +++ .../opik/domain/ExperimentItemService.java | 66 ++++- .../comet/opik/domain/ExperimentService.java | 8 + .../v1/priv/ExperimentsResourceTest.java | 240 +++++++++++++----- 7 files changed, 363 insertions(+), 62 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemStreamRequest.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemStreamRequest.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemStreamRequest.java new file mode 100644 index 0000000000..97fc11bcc2 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ExperimentItemStreamRequest.java @@ -0,0 +1,25 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; +import jakarta.validation.constraints.NotBlank; +import lombok.Builder; + +import java.util.UUID; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record ExperimentItemStreamRequest( + @NotBlank String experimentName, + @Min(1) @Max(2000) Integer limit, + UUID lastRetrievedId) { + + @Override + public Integer limit() { + return limit == null ? 500 : limit; + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java index 816535f380..f28ebf76a4 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ExperimentsResource.java @@ -3,6 +3,7 @@ import com.codahale.metrics.annotation.Timed; import com.comet.opik.api.Experiment; import com.comet.opik.api.ExperimentItem; +import com.comet.opik.api.ExperimentItemStreamRequest; import com.comet.opik.api.ExperimentItemsBatch; import com.comet.opik.api.ExperimentItemsDelete; import com.comet.opik.api.ExperimentSearchCriteria; @@ -12,9 +13,11 @@ import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.utils.AsyncUtils; import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.JsonNode; import io.dropwizard.jersey.errors.ErrorMessage; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.headers.Header; +import io.swagger.v3.oas.annotations.media.ArraySchema; import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; @@ -40,6 +43,7 @@ import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.glassfish.jersey.server.ChunkedOutput; import java.util.Set; import java.util.UUID; @@ -147,6 +151,24 @@ public Response getExperimentItem(@PathParam("id") UUID id) { return Response.ok().entity(experimentItem).build(); } + @POST + @Path("/items/stream") + @Produces(MediaType.APPLICATION_OCTET_STREAM) + @Operation(operationId = "streamExperimentItems", summary = "Stream experiment items", description = "Stream experiment items", responses = { + @ApiResponse(responseCode = "200", description = "Experiment items stream or error during process", content = @Content(array = @ArraySchema(schema = @Schema(anyOf = { + ExperimentItem.class, + ErrorMessage.class + }), maxItems = 2000))) + }) + public ChunkedOutput streamExperimentItems( + @RequestBody(content = @Content(schema = @Schema(implementation = ExperimentItemStreamRequest.class))) @NotNull @Valid ExperimentItemStreamRequest request) { + var workspaceId = requestContext.get().getWorkspaceId(); + log.info("Streaming experiment items by '{}', workspaceId '{}'", request, workspaceId); + var stream = experimentItemService.getExperimentItemsStream(request); + log.info("Streamed dataset items by '{}', workspaceId '{}'", request, workspaceId); + return stream; + } + @POST @Path("/items") @Operation(operationId = "createExperimentItems", summary = "Create experiment items", description = "Create experiment items", responses = { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentDAO.java index cb5e7558e0..0b6ec3aced 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentDAO.java @@ -5,6 +5,7 @@ import com.comet.opik.api.FeedbackScoreAverage; import com.comet.opik.utils.JsonUtils; import com.fasterxml.jackson.databind.JsonNode; +import com.google.common.base.Preconditions; import io.r2dbc.spi.Connection; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.Result; @@ -16,6 +17,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.reactivestreams.Publisher; import org.stringtemplate.v4.ST; import reactor.core.publisher.Flux; @@ -338,6 +340,18 @@ SELECT count(id) as count ; """; + private static final String FIND_BY_NAME = """ + SELECT + *, + null AS feedback_scores, + null AS trace_count + FROM experiments + WHERE workspace_id = :workspace_id + AND ilike(name, CONCAT('%', :name, '%')) + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + """; + private static final String FIND_EXPERIMENT_AND_WORKSPACE_BY_DATASET_IDS = """ SELECT id, workspace_id @@ -485,6 +499,20 @@ private void bindSearchCriteria(Statement statement, ExperimentSearchCriteria cr } } + Flux findByName(String name) { + Preconditions.checkArgument(StringUtils.isNotBlank(name), "Argument 'name' must not be blank"); + return Mono.from(connectionFactory.create()) + .flatMapMany(connection -> findByName(name, connection)) + .flatMap(this::mapToDto); + } + + private Publisher findByName(String name, Connection connection) { + log.info("Finding experiment by name '{}'", name); + var statement = connection.createStatement(FIND_BY_NAME) + .bind("name", name); + return makeFluxContextAware(bindWorkspaceIdToFlux(statement)); + } + public Flux getExperimentWorkspaces(@NonNull Set experimentIds) { if (experimentIds.isEmpty()) { return Flux.empty(); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemDAO.java index 08594f56af..35f5b8b0bf 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemDAO.java @@ -83,6 +83,19 @@ INSERT INTO experiment_items ( ; """; + private static final String STREAM = """ + SELECT + * + FROM experiment_items + WHERE workspace_id = :workspace_id + AND experiment_id IN :experiment_ids + AND id \\< :lastRetrievedId + ORDER BY experiment_id DESC, id DESC, last_updated_at DESC + LIMIT 1 BY id + LIMIT :limit + ; + """; + private static final String DELETE = """ DELETE FROM experiment_items WHERE id IN :ids @@ -201,6 +214,29 @@ private Publisher get(UUID id, Connection connection) { return makeFluxContextAware(bindWorkspaceIdToFlux(statement)); } + public Flux getItems(Set experimentIds, int limit, UUID lastRetrievedId) { + return Mono.from(connectionFactory.create()) + .flatMapMany(connection -> getItems(experimentIds, limit, lastRetrievedId, connection)) + .flatMap(this::mapToExperimentItem); + } + + private Publisher getItems( + Set experimentIds, int limit, UUID lastRetrievedId, Connection connection) { + log.info("Streaming experiment items by experimentIds count '{}', limit '{}', lastRetrievedId '{}'", + experimentIds.size(), limit, lastRetrievedId); + var template = new ST(STREAM); + if (lastRetrievedId != null) { + template.add("lastRetrievedId", lastRetrievedId); + } + var statement = connection.createStatement(template.render()) + .bind("experiment_ids", experimentIds.toArray(UUID[]::new)) + .bind("limit", limit); + if (lastRetrievedId != null) { + statement.bind("lastRetrievedId", lastRetrievedId); + } + return makeFluxContextAware(bindWorkspaceIdToFlux(statement)); + } + public Mono delete(Set ids) { Preconditions.checkArgument(CollectionUtils.isNotEmpty(ids), "Argument 'ids' must not be empty"); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemService.java index 39672965dc..7f1e568699 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentItemService.java @@ -1,9 +1,15 @@ package com.comet.opik.domain; +import com.comet.opik.api.Experiment; import com.comet.opik.api.ExperimentItem; +import com.comet.opik.api.ExperimentItemStreamRequest; import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.utils.JsonUtils; +import com.fasterxml.jackson.databind.JsonNode; import com.google.common.base.Preconditions; +import io.dropwizard.jersey.errors.ErrorMessage; import jakarta.inject.Inject; +import jakarta.inject.Provider; import jakarta.inject.Singleton; import jakarta.ws.rs.ClientErrorException; import jakarta.ws.rs.NotFoundException; @@ -12,10 +18,16 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; +import org.glassfish.jersey.server.ChunkedOutput; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.Set; import java.util.UUID; +import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; @Singleton @@ -26,6 +38,7 @@ public class ExperimentItemService { private final @NonNull ExperimentItemDAO experimentItemDAO; private final @NonNull ExperimentService experimentService; private final @NonNull DatasetItemDAO datasetItemDAO; + private final @NonNull Provider requestContext; public Mono create(Set experimentItems) { Preconditions.checkArgument(CollectionUtils.isNotEmpty(experimentItems), @@ -116,6 +129,58 @@ private NotFoundException newNotFoundException(UUID id) { return new NotFoundException(message); } + public ChunkedOutput getExperimentItemsStream(@NonNull ExperimentItemStreamRequest request) { + var outputStream = new ChunkedOutput(JsonNode.class, "\r\n"); + var workspaceId = requestContext.get().getWorkspaceId(); + var userName = requestContext.get().getUserName(); + var workspaceName = requestContext.get().getWorkspaceName(); + log.info("Getting experiment items stream by '{}', workspaceId '{}'", request, workspaceId); + Schedulers.boundedElastic() + .schedule(() -> Mono + .fromCallable(() -> experimentService.findByName(request.experimentName())) + .subscribeOn(Schedulers.boundedElastic()) + .flatMap(experiments -> experiments.map(Experiment::id).collect(Collectors.toUnmodifiableSet())) + .flatMapMany( + experimentIds -> experimentItemDAO.getItems( + experimentIds, request.limit(), request.lastRetrievedId())) + .doOnNext(item -> sendItem(item, outputStream)) + .onErrorResume(throwable -> handleError(throwable, outputStream)) + .doFinally(signalType -> close(outputStream)) + .contextWrite(ctx -> ctx.put(RequestContext.USER_NAME, userName) + .put(RequestContext.WORKSPACE_NAME, workspaceName) + .put(RequestContext.WORKSPACE_ID, workspaceId)) + .subscribe()); + log.info("Got experiment items stream by '{}', workspaceId '{}'", request, workspaceId); + return outputStream; + } + + private void sendItem(ExperimentItem item, ChunkedOutput outputStream) { + try { + outputStream.write(JsonUtils.readTree(item)); + } catch (IOException exception) { + throw new UncheckedIOException(exception); + } + } + + private Flux handleError(Throwable throwable, ChunkedOutput outputStream) { + if (throwable instanceof TimeoutException) { + try { + outputStream.write(JsonUtils.readTree(new ErrorMessage(500, "Streaming operation timed out"))); + } catch (IOException ioException) { + log.error("Failed to stream error to client", ioException); + } + } + return Flux.error(throwable); + } + + private void close(ChunkedOutput outputStream) { + try { + outputStream.close(); + } catch (IOException exception) { + log.error("Error while closing experiment items stream", exception); + } + } + public Mono delete(@NonNull Set ids) { Preconditions.checkArgument(CollectionUtils.isNotEmpty(ids), "Argument 'ids' must not be empty"); @@ -123,5 +188,4 @@ public Mono delete(@NonNull Set ids) { log.info("Deleting experiment items, count '{}'", ids.size()); return experimentItemDAO.delete(ids).then(); } - } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentService.java index 6576742993..3e06bac95b 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ExperimentService.java @@ -6,6 +6,7 @@ import com.comet.opik.api.ExperimentSearchCriteria; import com.comet.opik.api.error.EntityAlreadyExistsException; import com.comet.opik.infrastructure.auth.RequestContext; +import com.google.common.base.Preconditions; import jakarta.inject.Inject; import jakarta.inject.Singleton; import jakarta.ws.rs.ClientErrorException; @@ -15,6 +16,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -60,6 +62,12 @@ public Mono find( })); } + public Flux findByName(String name) { + Preconditions.checkArgument(StringUtils.isNotBlank(name), "Argument 'name' must not be blank"); + log.info("Finding experiments by name '{}'", name); + return experimentDAO.findByName(name); + } + public Mono getById(@NonNull UUID id) { log.info("Getting experiment by id '{}'", id); return experimentDAO.getById(id) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java index 3b4e852e8b..e55721d51d 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java @@ -4,6 +4,7 @@ import com.comet.opik.api.DatasetItemBatch; import com.comet.opik.api.Experiment; import com.comet.opik.api.ExperimentItem; +import com.comet.opik.api.ExperimentItemStreamRequest; import com.comet.opik.api.ExperimentItemsBatch; import com.comet.opik.api.ExperimentItemsDelete; import com.comet.opik.api.FeedbackScore; @@ -22,16 +23,22 @@ import com.comet.opik.api.resources.utils.WireMockUtils; import com.comet.opik.domain.FeedbackScoreMapper; import com.comet.opik.podam.PodamFactoryUtils; +import com.comet.opik.utils.JsonUtils; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.uuid.Generators; import com.fasterxml.uuid.impl.TimeBasedEpochGenerator; import com.github.tomakehurst.wiremock.client.WireMock; import com.redis.testcontainers.RedisContainer; import io.dropwizard.jersey.errors.ErrorMessage; import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.core.GenericType; import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; import org.apache.commons.lang3.RandomStringUtils; import org.apache.commons.lang3.StringUtils; import org.assertj.core.api.recursive.comparison.RecursiveComparisonConfiguration; +import org.glassfish.jersey.client.ChunkedInput; import org.jdbi.v3.core.Jdbi; import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.AfterAll; @@ -57,7 +64,9 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.sql.SQLException; +import java.util.ArrayList; import java.util.Collection; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Set; @@ -95,13 +104,19 @@ class ExperimentsResourceTest { private static final String[] EXPERIMENT_IGNORED_FIELDS = new String[]{ "id", "datasetId", "name", "feedbackScores", "traceCount", "createdAt", "lastUpdatedAt", "createdBy", "lastUpdatedBy"}; - public static final String[] IGNORED_FIELDS = {"input", "output", "feedbackScores", "createdAt", "lastUpdatedAt", - "createdBy", "lastUpdatedBy"}; + public static final String[] ITEM_IGNORED_FIELDS = {"input", "output", "feedbackScores", "createdAt", + "lastUpdatedAt", "createdBy", "lastUpdatedBy"}; private static final String WORKSPACE_ID = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); private static final String TEST_WORKSPACE = UUID.randomUUID().toString(); + private static final GenericType> CHUNKED_INPUT_STRING_GENERIC_TYPE = new GenericType<>() { + }; + + private static final TypeReference EXPERIMENT_ITEM_TYPE_REFERENCE = new TypeReference<>() { + }; + private static final TimeBasedEpochGenerator GENERATOR = Generators.timeBasedEpochGenerator(); private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer(); @@ -153,10 +168,10 @@ private static void mockTargetWorkspace(String apiKey, String workspaceName, Str AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER); } - private static void mockSessionCookieTargetWorkspace(String sessionToken, String workspaceName, - String workspaceId) { - AuthTestUtils.mockSessionCookieTargetWorkspace(wireMock.server(), sessionToken, workspaceName, workspaceId, - USER); + private static void mockSessionCookieTargetWorkspace( + String sessionToken, String workspaceName, String workspaceId) { + AuthTestUtils.mockSessionCookieTargetWorkspace( + wireMock.server(), sessionToken, workspaceName, workspaceId, USER); } @AfterAll @@ -181,7 +196,6 @@ Stream credentials() { @BeforeEach void setUp() { - wireMock.server().stubFor( post(urlPathEqualTo("/opik/auth")) .withHeader(HttpHeaders.AUTHORIZATION, equalTo(fakeApikey)) @@ -294,8 +308,7 @@ void deleteExperimentItems__whenApiKeyIsPresent__thenReturnProperResponse(String createAndAssert(createRequest, okApikey, workspaceName); - createRequest.experimentItems() - .forEach(item -> ExperimentsResourceTest.this.getAndAssert(item, workspaceName, okApikey)); + createRequest.experimentItems().forEach(item -> getAndAssert(item, workspaceName, okApikey)); var ids = createRequest.experimentItems().stream().map(ExperimentItem::id).collect(Collectors.toSet()); var deleteRequest = ExperimentItemsDelete.builder().ids(ids).build(); @@ -346,7 +359,6 @@ void createExperimentItems__whenApiKeyIsPresent__thenReturnProperResponse(String @ParameterizedTest @MethodSource("credentials") void getExperimentItemById__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, boolean success) { - var workspaceName = UUID.randomUUID().toString(); var expectedExperimentItem = podamFactory.manufacturePojo(ExperimentItem.class); var id = expectedExperimentItem.id(); @@ -408,8 +420,8 @@ void setUp() { @ParameterizedTest @MethodSource("credentials") - void getById__whenSessionTokenIsPresent__thenReturnProperResponse(String currentSessionToken, boolean success, - String workspaceName) { + void getById__whenSessionTokenIsPresent__thenReturnProperResponse( + String currentSessionToken, boolean success, String workspaceName) { var expectedExperiment = podamFactory.manufacturePojo(Experiment.class); mockTargetWorkspace(API_KEY, workspaceName, WORKSPACE_ID); @@ -433,8 +445,8 @@ void getById__whenSessionTokenIsPresent__thenReturnProperResponse(String current @ParameterizedTest @MethodSource("credentials") - void create__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, boolean success, - String workspaceName) { + void create__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, boolean success, String workspaceName) { var expectedExperiment = podamFactory.manufacturePojo(Experiment.class); mockTargetWorkspace(API_KEY, sessionToken, WORKSPACE_ID); @@ -456,8 +468,8 @@ void create__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionT @ParameterizedTest @MethodSource("credentials") - void find__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, boolean success, - String workspaceName) { + void find__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, boolean success, String workspaceName) { var workspaceId = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -495,15 +507,14 @@ void find__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionTok @ParameterizedTest @MethodSource("credentials") - void deleteExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, - boolean success, String workspaceName) { + void deleteExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, boolean success, String workspaceName) { mockTargetWorkspace(API_KEY, workspaceName, WORKSPACE_ID); var createRequest = podamFactory.manufacturePojo(ExperimentItemsBatch.class); createAndAssert(createRequest, API_KEY, workspaceName); - createRequest.experimentItems() - .forEach(item -> ExperimentsResourceTest.this.getAndAssert(item, workspaceName, API_KEY)); + createRequest.experimentItems().forEach(item -> getAndAssert(item, workspaceName, API_KEY)); var ids = createRequest.experimentItems().stream().map(ExperimentItem::id).collect(Collectors.toSet()); var deleteRequest = ExperimentItemsDelete.builder().ids(ids).build(); @@ -527,8 +538,8 @@ void deleteExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse( @ParameterizedTest @MethodSource("credentials") - void createExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, - boolean success, String workspaceName) { + void createExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, boolean success, String workspaceName) { var request = podamFactory.manufacturePojo(ExperimentItemsBatch.class); @@ -550,8 +561,8 @@ void createExperimentItems__whenSessionTokenIsPresent__thenReturnProperResponse( @ParameterizedTest @MethodSource("credentials") - void getExperimentItemById__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, - boolean success, String workspaceName) { + void getExperimentItemById__whenSessionTokenIsPresent__thenReturnProperResponse( + String sessionToken, boolean success, String workspaceName) { mockTargetWorkspace(API_KEY, workspaceName, WORKSPACE_ID); @@ -602,12 +613,11 @@ void findByDatasetId() { .datasetName(datasetName) .build()) .toList(); - experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment, - apiKey, workspaceName)); + experiments.forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var unexpectedExperiments = List.of(podamFactory.manufacturePojo(Experiment.class)); - unexpectedExperiments.forEach(expectedExperiment -> ExperimentsResourceTest.this - .createAndAssert(expectedExperiment, apiKey, workspaceName)); + unexpectedExperiments + .forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var pageSize = experiments.size() - 2; var datasetId = getAndAssert(experiments.getFirst().id(), experiments.getFirst(), workspaceName, apiKey) @@ -651,12 +661,12 @@ void findByName(String name, String nameQueryParam) { .name(name) .build()) .toList(); - experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment, + experiments.forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var unexpectedExperiments = List.of(podamFactory.manufacturePojo(Experiment.class)); - unexpectedExperiments.forEach(expectedExperiment -> ExperimentsResourceTest.this - .createAndAssert(expectedExperiment, apiKey, workspaceName)); + unexpectedExperiments + .forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var pageSize = experiments.size() - 2; UUID datasetId = null; @@ -672,7 +682,6 @@ void findByName(String name, String nameQueryParam) { @Test void findByDatasetIdAndName() { - var workspaceName = UUID.randomUUID().toString(); var workspaceId = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -690,12 +699,12 @@ void findByDatasetIdAndName() { .metadata(null) .build()) .toList(); - experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment, + experiments.forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var unexpectedExperiments = List.of(podamFactory.manufacturePojo(Experiment.class)); - unexpectedExperiments.forEach(expectedExperiment -> ExperimentsResourceTest.this - .createAndAssert(expectedExperiment, apiKey, workspaceName)); + unexpectedExperiments + .forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var pageSize = experiments.size() - 2; var datasetId = getAndAssert(experiments.getFirst().id(), experiments.getFirst(), workspaceName, apiKey) @@ -720,7 +729,7 @@ void findAll() { var experiments = PodamFactoryUtils.manufacturePojoList(podamFactory, Experiment.class); - experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment, + experiments.forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var page = 1; @@ -760,8 +769,7 @@ void findAllAndCalculateFeedbackAvg() { .build()) .toList(); - experiments.forEach(expectedExperiment -> ExperimentsResourceTest.this.createAndAssert(expectedExperiment, - apiKey, workspaceName)); + experiments.forEach(expectedExperiment -> createAndAssert(expectedExperiment, apiKey, workspaceName)); var noScoreExperiment = podamFactory.manufacturePojo(Experiment.class); createAndAssert(noScoreExperiment, apiKey, workspaceName); @@ -1007,8 +1015,8 @@ private void deleteTrace(UUID id, String apiKey, String workspaceName) { } } - private @NotNull Map> getExpectedScoresPerExperiment(List experiments, - List experimentItems) { + private @NotNull Map> getExpectedScoresPerExperiment( + List experiments, List experimentItems) { return experiments.stream() .map(experiment -> Map.entry(experiment.id(), experimentItems .stream() @@ -1328,7 +1336,6 @@ void getNotFound() { @Test void createAndGetWithDeletedTrace() { - var workspaceName = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); var workspaceId = UUID.randomUUID().toString(); @@ -1426,7 +1433,6 @@ void createAndGetWithDeletedTrace() { } private ExperimentItemsBatch addRandomExperiments(List experimentItems) { - // When storing the experiment items in batch, adding some more unrelated random ones var experimentItemsBatch = podamFactory.manufacturePojo(ExperimentItemsBatch.class); experimentItemsBatch = experimentItemsBatch.toBuilder() @@ -1440,13 +1446,11 @@ private ExperimentItemsBatch addRandomExperiments(List experimen private Map getScoresMap(Experiment experiment) { List feedbackScores = experiment.feedbackScores(); - if (feedbackScores != null) { return feedbackScores .stream() .collect(Collectors.toMap(FeedbackScoreAverage::name, FeedbackScoreAverage::value)); } - return null; } @@ -1632,13 +1636,77 @@ void createAndGet() { createAndAssert(request, API_KEY, TEST_WORKSPACE); - request.experimentItems() - .forEach(item -> ExperimentsResourceTest.this.getAndAssert(item, TEST_WORKSPACE, API_KEY)); + request.experimentItems().forEach(item -> getAndAssert(item, TEST_WORKSPACE, API_KEY)); } @Test - void insertInvalidDatasetItemWorkspace() { + void createAndStream() { + var apiKey = UUID.randomUUID().toString(); + var workspaceName = RandomStringUtils.randomAlphanumeric(10); + var workspaceId = UUID.randomUUID().toString(); + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var experiment1 = podamFactory.manufacturePojo(Experiment.class); + createAndAssert(experiment1, apiKey, workspaceName); + + var experiment2 = podamFactory.manufacturePojo(Experiment.class).toBuilder() + .name(experiment1.name().substring(1, experiment1.name().length() - 1).toLowerCase()) + .build(); + createAndAssert(experiment2, apiKey, workspaceName); + + var experiment3 = podamFactory.manufacturePojo(Experiment.class); + createAndAssert(experiment3, apiKey, workspaceName); + + var experimentItems1 = PodamFactoryUtils.manufacturePojoList(podamFactory, ExperimentItem.class).stream() + .map(experimentItem -> experimentItem.toBuilder().experimentId(experiment1.id()).build()) + .collect(Collectors.toUnmodifiableSet()); + var createRequest1 = ExperimentItemsBatch.builder().experimentItems(experimentItems1).build(); + createAndAssert(createRequest1, apiKey, workspaceName); + + var experimentItems2 = PodamFactoryUtils.manufacturePojoList(podamFactory, ExperimentItem.class).stream() + .map(experimentItem -> experimentItem.toBuilder().experimentId(experiment2.id()).build()) + .collect(Collectors.toUnmodifiableSet()); + var createRequest2 = ExperimentItemsBatch.builder().experimentItems(experimentItems2).build(); + createAndAssert(createRequest2, apiKey, workspaceName); + + var experimentItems3 = PodamFactoryUtils.manufacturePojoList(podamFactory, ExperimentItem.class).stream() + .map(experimentItem -> experimentItem.toBuilder().experimentId(experiment3.id()).build()) + .collect(Collectors.toUnmodifiableSet()); + var createRequest3 = ExperimentItemsBatch.builder().experimentItems(experimentItems3).build(); + createAndAssert(createRequest3, apiKey, workspaceName); + + var size = experimentItems1.size() + experimentItems2.size(); + var limit = size / 2; + + var expectedExperimentItems = Stream.concat(experimentItems1.stream(), experimentItems2.stream()) + .sorted(Comparator.comparing(ExperimentItem::experimentId).thenComparing(ExperimentItem::id)) + .toList() + .reversed(); + + var expectedExperimentItems1 = expectedExperimentItems.subList(0, limit); + var expectedExperimentItems2 = expectedExperimentItems.subList(limit, size); + + var streamRequest1 = ExperimentItemStreamRequest.builder() + .experimentName(experiment2.name()) + .limit(limit) + .build(); + var unexpectedExperimentItems1 = Stream.concat(expectedExperimentItems2.stream(), experimentItems3.stream()) + .toList(); + streamAndAssert(streamRequest1, expectedExperimentItems1, unexpectedExperimentItems1, apiKey, + workspaceName); + var streamRequest2 = ExperimentItemStreamRequest.builder() + .experimentName(experiment2.name()) + .lastRetrievedId(expectedExperimentItems1.getLast().id()) + .build(); + var unexpectedExperimentItems2 = Stream.concat(expectedExperimentItems1.stream(), experimentItems3.stream()) + .toList(); + streamAndAssert(streamRequest2, expectedExperimentItems2, unexpectedExperimentItems2, apiKey, + workspaceName); + } + + @Test + void insertInvalidDatasetItemWorkspace() { var workspaceName = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -1671,7 +1739,6 @@ void insertInvalidDatasetItemWorkspace() { @Test void insertInvalidExperimentWorkspace() { - var workspaceName = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); var workspaceId = UUID.randomUUID().toString(); @@ -1702,7 +1769,7 @@ void insertInvalidExperimentWorkspace() { } } - UUID createDatasetItem(String workspaceName, String apiKey) { + private UUID createDatasetItem(String workspaceName, String apiKey) { var item = podamFactory.manufacturePojo(DatasetItem.class); var batch = podamFactory.manufacturePojo(DatasetItemBatch.class).toBuilder() @@ -1750,7 +1817,6 @@ Stream insertInvalidId() { @ParameterizedTest @MethodSource void insertInvalidId(ExperimentItem experimentItem, String expectedErrorMessage) { - var request = ExperimentItemsBatch.builder() .experimentItems(Set.of(experimentItem)).build(); var expectedError = new com.comet.opik.api.error.ErrorMessage( @@ -1781,8 +1847,7 @@ void delete() { var createRequest = podamFactory.manufacturePojo(ExperimentItemsBatch.class).toBuilder() .build(); createAndAssert(createRequest, API_KEY, TEST_WORKSPACE); - createRequest.experimentItems() - .forEach(item -> ExperimentsResourceTest.this.getAndAssert(item, TEST_WORKSPACE, API_KEY)); + createRequest.experimentItems().forEach(item -> getAndAssert(item, TEST_WORKSPACE, API_KEY)); var ids = createRequest.experimentItems().stream().map(ExperimentItem::id).collect(Collectors.toSet()); var deleteRequest = ExperimentItemsDelete.builder().ids(ids).build(); @@ -1797,7 +1862,7 @@ void delete() { assertThat(actualResponse.hasEntity()).isFalse(); } - ids.forEach(id -> ExperimentsResourceTest.this.getAndAssertNotFound(id, API_KEY, TEST_WORKSPACE)); + ids.forEach(id -> getAndAssertNotFound(id, API_KEY, TEST_WORKSPACE)); } } @@ -1828,19 +1893,31 @@ private void getAndAssert(ExperimentItem expectedExperimentItem, String workspac assertThat(actualExperimentItem) .usingRecursiveComparison() - .ignoringFields(IGNORED_FIELDS) + .ignoringFields(ITEM_IGNORED_FIELDS) .isEqualTo(expectedExperimentItem); - assertThat(actualExperimentItem.input()).isNull(); - assertThat(actualExperimentItem.output()).isNull(); - assertThat(actualExperimentItem.feedbackScores()).isNull(); - assertThat(actualExperimentItem.createdAt()).isAfter(expectedExperimentItem.createdAt()); - assertThat(actualExperimentItem.lastUpdatedAt()).isAfter(expectedExperimentItem.lastUpdatedAt()); - assertThat(actualExperimentItem.createdBy()).isEqualTo(USER); - assertThat(actualExperimentItem.lastUpdatedBy()).isEqualTo(USER); + assertIgnoredFields(actualExperimentItem, expectedExperimentItem); } } + private void assertIgnoredFields( + List actualExperimentItems, List expectedExperimentItems) { + assertThat(actualExperimentItems).hasSameSizeAs(expectedExperimentItems); + for (int i = 0; i < actualExperimentItems.size(); i++) { + assertIgnoredFields(actualExperimentItems.get(i), expectedExperimentItems.get(i)); + } + } + + private void assertIgnoredFields(ExperimentItem actualExperimentItem, ExperimentItem expectedExperimentItem) { + assertThat(actualExperimentItem.input()).isNull(); + assertThat(actualExperimentItem.output()).isNull(); + assertThat(actualExperimentItem.feedbackScores()).isNull(); + assertThat(actualExperimentItem.createdAt()).isAfter(expectedExperimentItem.createdAt()); + assertThat(actualExperimentItem.lastUpdatedAt()).isAfter(expectedExperimentItem.lastUpdatedAt()); + assertThat(actualExperimentItem.createdBy()).isEqualTo(USER); + assertThat(actualExperimentItem.lastUpdatedBy()).isEqualTo(USER); + } + private void getAndAssertNotFound(UUID id, String apiKey, String workspaceName) { var expectedError = new ErrorMessage(404, "Not found experiment item with id '%s'".formatted(id)); try (var actualResponse = client.target(getExperimentItemsPath()) @@ -1858,6 +1935,47 @@ private void getAndAssertNotFound(UUID id, String apiKey, String workspaceName) } } + private void streamAndAssert( + ExperimentItemStreamRequest request, + List expectedExperimentItems, + List unexpectedExperimentItems, + String apiKey, + String workspaceName) { + try (var actualResponse = client.target(getExperimentItemsPath()) + .path("stream") + .request() + .accept(MediaType.APPLICATION_OCTET_STREAM) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(request))) { + + assertThat(actualResponse.getStatus()).isEqualTo(200); + + var actualExperimentItems = getStreamedItems(actualResponse); + + assertThat(actualExperimentItems) + .usingRecursiveFieldByFieldElementComparatorIgnoringFields(ITEM_IGNORED_FIELDS) + .containsExactlyElementsOf(expectedExperimentItems); + + assertIgnoredFields(actualExperimentItems, expectedExperimentItems); + + assertThat(actualExperimentItems) + .usingRecursiveFieldByFieldElementComparatorIgnoringFields(ITEM_IGNORED_FIELDS) + .doesNotContainAnyElementsOf(unexpectedExperimentItems); + } + } + + private List getStreamedItems(Response response) { + var items = new ArrayList(); + try (var inputStream = response.readEntity(CHUNKED_INPUT_STRING_GENERIC_TYPE)) { + String stringItem; + while ((stringItem = inputStream.read()) != null) { + items.add(JsonUtils.readValue(stringItem, EXPERIMENT_ITEM_TYPE_REFERENCE)); + } + } + return items; + } + private String getExperimentsPath() { return URL_TEMPLATE.formatted(baseURI); }