diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/filter/ExperimentsComparisonField.java b/apps/opik-backend/src/main/java/com/comet/opik/api/filter/ExperimentsComparisonField.java new file mode 100644 index 0000000000..d4153c58fb --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/filter/ExperimentsComparisonField.java @@ -0,0 +1,19 @@ +package com.comet.opik.api.filter; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +@RequiredArgsConstructor +@Getter +public enum ExperimentsComparisonField implements Field { + + INPUT(INPUT_QUERY_PARAM, FieldType.STRING), + EXPECTED_OUTPUT(EXPECTED_OUTPUT_QUERY_PARAM, FieldType.STRING), + OUTPUT(OUTPUT_QUERY_PARAM, FieldType.STRING), + METADATA(METADATA_QUERY_PARAM, FieldType.DICTIONARY), + FEEDBACK_SCORES(FEEDBACK_SCORES_QUERY_PARAM, FieldType.FEEDBACK_SCORES_NUMBER), + ; + + private final String queryParamField; + private final FieldType type; +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/filter/ExperimentsComparisonFilter.java b/apps/opik-backend/src/main/java/com/comet/opik/api/filter/ExperimentsComparisonFilter.java new file mode 100644 index 0000000000..4bf48f7b75 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/filter/ExperimentsComparisonFilter.java @@ -0,0 +1,28 @@ +package com.comet.opik.api.filter; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.type.TypeReference; +import lombok.experimental.SuperBuilder; + +import java.util.List; + +@SuperBuilder(toBuilder = true) +public class ExperimentsComparisonFilter extends FilterImpl { + + public static final TypeReference<List<ExperimentsComparisonFilter>> LIST_TYPE_REFERENCE = new TypeReference<>() { + }; + + @JsonCreator + public ExperimentsComparisonFilter(@JsonProperty(value = "field", required = true) ExperimentsComparisonField field, + @JsonProperty(value = "operator", required = true) Operator operator, + 
@JsonProperty("key") String key, + @JsonProperty(value = "value", required = true) String value) { + super(field, operator, key, value); + } + + @Override + public Filter build(String decodedValue) { + return toBuilder().value(decodedValue).build(); + } +} \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/filter/Field.java b/apps/opik-backend/src/main/java/com/comet/opik/api/filter/Field.java index 333c16c6db..ff6a490544 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/filter/Field.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/filter/Field.java @@ -11,6 +11,7 @@ public interface Field { String INPUT_QUERY_PARAM = "input"; String OUTPUT_QUERY_PARAM = "output"; String METADATA_QUERY_PARAM = "metadata"; + String EXPECTED_OUTPUT_QUERY_PARAM = "expected_output"; String TAGS_QUERY_PARAM = "tags"; String USAGE_COMPLETION_TOKENS_QUERY_PARAM = "usage.completion_tokens"; String USAGE_PROMPT_TOKENS_QUERY_PARAM = "usage.prompt_tokens"; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java index 757310c117..377a7d26b2 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/DatasetsResource.java @@ -11,6 +11,8 @@ import com.comet.opik.api.DatasetItemsDelete; import com.comet.opik.api.DatasetUpdate; import com.comet.opik.api.ExperimentItem; +import com.comet.opik.api.filter.ExperimentsComparisonFilter; +import com.comet.opik.api.filter.FiltersFactory; import com.comet.opik.domain.DatasetItemService; import com.comet.opik.domain.DatasetService; import com.comet.opik.domain.FeedbackScoreDAO; @@ -82,6 +84,7 @@ public class DatasetsResource { private final @NonNull DatasetService service; private final @NonNull DatasetItemService itemService; private 
final @NonNull Provider requestContext; + private final @NonNull FiltersFactory filtersFactory; private final @NonNull IdGenerator idGenerator; private final @NonNull Streamer streamer; @@ -348,9 +351,12 @@ public Response findDatasetItemsWithExperimentItems( var experimentIds = getExperimentIds(experimentIdsQueryParam); + var queryFilters = filtersFactory.newFilters(filters, ExperimentsComparisonFilter.LIST_TYPE_REFERENCE); + var datasetItemSearchCriteria = DatasetItemSearchCriteria.builder() .datasetId(datasetId) .experimentIds(experimentIds) + .filters(queryFilters) .entityType(FeedbackScoreDAO.EntityType.TRACE) .build(); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/AsyncContextUtils.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/AsyncContextUtils.java index acfc1c1a42..a21f249270 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/AsyncContextUtils.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/AsyncContextUtils.java @@ -1,6 +1,7 @@ package com.comet.opik.domain; -import com.comet.opik.utils.AsyncUtils; +import com.comet.opik.utils.AsyncUtils.ContextAwareAction; +import com.comet.opik.utils.AsyncUtils.ContextAwareStream; import io.r2dbc.spi.Result; import io.r2dbc.spi.Statement; import reactor.core.publisher.Flux; @@ -8,21 +9,21 @@ class AsyncContextUtils { - static AsyncUtils.ContextAwareStream bindWorkspaceIdToFlux(Statement statement) { + static ContextAwareStream bindWorkspaceIdToFlux(Statement statement) { return (userName, workspaceName, workspaceId) -> { statement.bind("workspace_id", workspaceId); return Flux.from(statement.execute()); }; } - static AsyncUtils.ContextAwareAction bindWorkspaceIdToMono(Statement statement) { + static ContextAwareAction bindWorkspaceIdToMono(Statement statement) { return (userName, workspaceName, workspaceId) -> { statement.bind("workspace_id", workspaceId); return Mono.from(statement.execute()); }; } - static AsyncUtils.ContextAwareAction 
bindUserNameAndWorkspaceContext(Statement statement) { + static ContextAwareAction bindUserNameAndWorkspaceContext(Statement statement) { return (userName, workspaceName, workspaceId) -> { statement.bind("user_name", userName); statement.bind("workspace_id", workspaceId); @@ -31,7 +32,7 @@ static AsyncUtils.ContextAwareAction bindUserNameAndWorkspaceC }; } - static AsyncUtils.ContextAwareStream bindUserNameAndWorkspaceContextToStream( + static ContextAwareStream bindUserNameAndWorkspaceContextToStream( Statement statement) { return (userName, workspaceName, workspaceId) -> { statement.bind("user_name", userName); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java index 236ff63f0f..026ee5f77e 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetItemDAO.java @@ -6,6 +6,8 @@ import com.comet.opik.api.ExperimentItem; import com.comet.opik.api.FeedbackScore; import com.comet.opik.api.ScoreSource; +import com.comet.opik.domain.filter.FilterQueryBuilder; +import com.comet.opik.domain.filter.FilterStrategy; import com.comet.opik.infrastructure.db.TransactionTemplateAsync; import com.comet.opik.utils.JsonUtils; import com.fasterxml.jackson.databind.JsonNode; @@ -178,20 +180,51 @@ INSERT INTO dataset_items ( FROM dataset_items WHERE dataset_id = :datasetId AND workspace_id = :workspace_id + + AND + ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id ) AS di INNER JOIN ( SELECT - dataset_item_id - FROM experiment_items + dataset_item_id, + trace_id + FROM experiment_items ei + + INNER JOIN ( + SELECT + id + FROM traces + WHERE workspace_id = :workspace_id + + AND + + + AND id in ( + SELECT + entity_id + FROM ( + SELECT * + FROM feedback_scores + WHERE entity_type = 'trace' + AND workspace_id = :workspace_id + ORDER BY entity_id DESC, last_updated_at DESC + LIMIT 1 
BY entity_id, name + ) + GROUP BY entity_id + HAVING + ) + + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + ) AS tfs ON ei.trace_id = tfs.id + WHERE experiment_id in :experimentIds AND workspace_id = :workspace_id ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id ) AS ei ON di.id = ei.dataset_item_id - GROUP BY - di.id ; """; @@ -237,6 +270,7 @@ INNER JOIN ( private static final String SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS = """ SELECT di.id AS id, + di.dataset_id AS dataset_id, di.input AS input, di.expected_output AS expected_output, di.metadata AS metadata, @@ -266,13 +300,45 @@ INNER JOIN ( FROM dataset_items WHERE dataset_id = :datasetId AND workspace_id = :workspace_id + + AND + ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id ) AS di INNER JOIN ( SELECT - * - FROM experiment_items + DISTINCT ei.* + FROM experiment_items ei + + INNER JOIN ( + SELECT + id + FROM traces + WHERE workspace_id = :workspace_id + + AND + + + AND id in ( + SELECT + entity_id + FROM ( + SELECT * + FROM feedback_scores + WHERE entity_type = 'trace' + AND workspace_id = :workspace_id + ORDER BY entity_id DESC, last_updated_at DESC + LIMIT 1 BY entity_id, name + ) + GROUP BY entity_id + HAVING + ) + + ORDER BY id DESC, last_updated_at DESC + LIMIT 1 BY id + ) AS tfs ON ei.trace_id = tfs.id + WHERE experiment_id in :experimentIds AND workspace_id = :workspace_id ORDER BY id DESC, last_updated_at DESC @@ -322,6 +388,7 @@ LEFT JOIN ( ) AS tfs ON ei.trace_id = tfs.id GROUP BY di.id, + di.dataset_id, di.input, di.expected_output, di.metadata, @@ -338,6 +405,7 @@ LEFT JOIN ( """; private final @NonNull TransactionTemplateAsync asyncTemplate; + private final @NonNull FilterQueryBuilder filterQueryBuilder; @Override @Trace(dispatcher = true) @@ -465,12 +533,14 @@ private List getFeedbackScores(Object feedbackScoresRaw) { if (feedbackScoresRaw instanceof List[] feedbackScoresArray) { var feedbackScores = Arrays.stream(feedbackScoresArray) .filter(feedbackScore -> 
CollectionUtils.isNotEmpty(feedbackScore) && - !CLICKHOUSE_FIXED_STRING_UUID_FIELD_NULL_VALUE.equals(feedbackScore.get(0).toString())) + !CLICKHOUSE_FIXED_STRING_UUID_FIELD_NULL_VALUE.equals(feedbackScore.getFirst().toString())) .map(feedbackScore -> FeedbackScore.builder() .name(feedbackScore.get(1).toString()) - .categoryName(Optional.ofNullable(feedbackScore.get(2)).map(Object::toString).orElse(null)) + .categoryName(Optional.ofNullable(feedbackScore.get(2)).map(Object::toString) + .filter(StringUtils::isNotEmpty).orElse(null)) .value(new BigDecimal(feedbackScore.get(3).toString())) - .reason(Optional.ofNullable(feedbackScore.get(4)).map(Object::toString).orElse(null)) + .reason(Optional.ofNullable(feedbackScore.get(4)).map(Object::toString) + .filter(StringUtils::isNotEmpty).orElse(null)) .source(ScoreSource.fromString(feedbackScore.get(5).toString())) .build()) .toList(); @@ -606,6 +676,34 @@ public Mono getItems(@NonNull UUID datasetId, int page, int siz }))); } + private ST newFindTemplate(String query, DatasetItemSearchCriteria datasetItemSearchCriteria) { + var template = new ST(query); + + Optional.ofNullable(datasetItemSearchCriteria.filters()) + .ifPresent(filters -> { + filterQueryBuilder.toAnalyticsDbFilters(filters, FilterStrategy.DATASET_ITEM) + .ifPresent(datasetItemFilters -> template.add("dataset_item_filters", datasetItemFilters)); + + filterQueryBuilder.toAnalyticsDbFilters(filters, FilterStrategy.EXPERIMENT_ITEM) + .ifPresent(experimentItemFilters -> template.add("experiment_item_filters", + experimentItemFilters)); + + filterQueryBuilder.toAnalyticsDbFilters(filters, FilterStrategy.FEEDBACK_SCORES) + .ifPresent(scoresFilters -> template.add("feedback_scores_filters", scoresFilters)); + }); + + return template; + } + + private void bindSearchCriteria(DatasetItemSearchCriteria datasetItemSearchCriteria, Statement statement) { + Optional.ofNullable(datasetItemSearchCriteria.filters()) + .ifPresent(filters -> { + 
filterQueryBuilder.bind(statement, filters, FilterStrategy.DATASET_ITEM); + filterQueryBuilder.bind(statement, filters, FilterStrategy.EXPERIMENT_ITEM); + filterQueryBuilder.bind(statement, filters, FilterStrategy.FEEDBACK_SCORES); + }); + } + @Override @Trace(dispatcher = true) public Mono getItems( @@ -615,37 +713,42 @@ public Mono getItems( Segment segmentCount = startSegment("dataset_items", "Clickhouse", "select_dataset_items_filters_count"); - return makeMonoContextAware((userName, workspaceName, workspaceId) -> asyncTemplate.nonTransaction( - connection -> Flux - .from(connection - .createStatement( - SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_COUNT) + return asyncTemplate.nonTransaction(connection -> { + + ST countTemplate = newFindTemplate(SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_COUNT, + datasetItemSearchCriteria); + + var statement = connection.createStatement(countTemplate.render()) + .bind("datasetId", datasetItemSearchCriteria.datasetId()) + .bind("experimentIds", datasetItemSearchCriteria.experimentIds().toArray(UUID[]::new)); + + bindSearchCriteria(datasetItemSearchCriteria, statement); + + return makeFluxContextAware(bindWorkspaceIdToFlux(statement)) + .doFinally(signalType -> segmentCount.end()) + .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, Long.class))) + .reduce(0L, Long::sum) + .flatMap(count -> { + Segment segment = startSegment("dataset_items", "Clickhouse", "select_dataset_items_filters"); + + ST selectTemplate = newFindTemplate(SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS, + datasetItemSearchCriteria); + + var selectStatement = connection.createStatement(selectTemplate.render()) .bind("datasetId", datasetItemSearchCriteria.datasetId()) - .bind("experimentIds", datasetItemSearchCriteria.experimentIds()) - .bind("workspace_id", workspaceId) - .execute()) - .doFinally(signalType -> segmentCount.end()) - .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, Long.class))) - .reduce(0L, Long::sum) - .flatMap(count -> { 
- Segment segment = startSegment("dataset_items", "Clickhouse", - "select_dataset_items_filters"); - - return Flux - .from(connection - .createStatement( - SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS) - .bind("datasetId", datasetItemSearchCriteria.datasetId()) - .bind("experimentIds", datasetItemSearchCriteria.experimentIds()) - .bind("entityType", datasetItemSearchCriteria.entityType().getType()) - .bind("workspace_id", workspaceId) - .bind("limit", size) - .bind("offset", (page - 1) * size) - .execute()) - .doFinally(signalType -> segment.end()) - .flatMap(this::mapItem) - .collectList() - .flatMap(items -> Mono.just(new DatasetItemPage(items, page, items.size(), count))); - }))); + .bind("experimentIds", datasetItemSearchCriteria.experimentIds().toArray(UUID[]::new)) + .bind("entityType", datasetItemSearchCriteria.entityType().getType()) + .bind("limit", size) + .bind("offset", (page - 1) * size); + + bindSearchCriteria(datasetItemSearchCriteria, selectStatement); + + return makeFluxContextAware(bindWorkspaceIdToFlux(selectStatement)) + .doFinally(signalType -> segment.end()) + .flatMap(this::mapItem) + .collectList() + .flatMap(items -> Mono.just(new DatasetItemPage(items, page, items.size(), count))); + }); + }); } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/filter/FilterQueryBuilder.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/filter/FilterQueryBuilder.java index c645b6c752..c80c74aac0 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/filter/FilterQueryBuilder.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/filter/FilterQueryBuilder.java @@ -1,5 +1,6 @@ package com.comet.opik.domain.filter; +import com.comet.opik.api.filter.ExperimentsComparisonField; import com.comet.opik.api.filter.Field; import com.comet.opik.api.filter.FieldType; import com.comet.opik.api.filter.Filter; @@ -33,6 +34,7 @@ public class FilterQueryBuilder { private static final String INPUT_ANALYTICS_DB = "input"; 
private static final String OUTPUT_ANALYTICS_DB = "output"; private static final String METADATA_ANALYTICS_DB = "metadata"; + private static final String EXPECTED_OUTPUT_ANALYTICS_DB = "expected_output"; private static final String TAGS_ANALYTICS_DB = "tags"; private static final String USAGE_COMPLETION_TOKENS_ANALYTICS_DB = "usage['completion_tokens']"; private static final String USAGE_PROMPT_TOKENS_ANALYTICS_DB = "usage['prompt_tokens']"; @@ -117,6 +119,15 @@ public class FilterQueryBuilder { .put(SpanField.FEEDBACK_SCORES, VALUE_ANALYTICS_DB) .build()); + private static final Map EXPERIMENTS_COMPARISON_FIELDS_MAP = new EnumMap<>( + ImmutableMap.builder() + .put(ExperimentsComparisonField.OUTPUT, OUTPUT_ANALYTICS_DB) + .put(ExperimentsComparisonField.INPUT, INPUT_ANALYTICS_DB) + .put(ExperimentsComparisonField.EXPECTED_OUTPUT, EXPECTED_OUTPUT_ANALYTICS_DB) + .put(ExperimentsComparisonField.METADATA, METADATA_ANALYTICS_DB) + .put(ExperimentsComparisonField.FEEDBACK_SCORES, VALUE_ANALYTICS_DB) + .build()); + private static final Map> FILTER_STRATEGY_MAP = new EnumMap<>(Map.of( FilterStrategy.TRACE, EnumSet.copyOf(ImmutableSet.builder() .add(TraceField.ID) @@ -149,7 +160,16 @@ public class FilterQueryBuilder { FilterStrategy.FEEDBACK_SCORES, ImmutableSet.builder() .add(TraceField.FEEDBACK_SCORES) .add(SpanField.FEEDBACK_SCORES) - .build())); + .add(ExperimentsComparisonField.FEEDBACK_SCORES) + .build(), + FilterStrategy.EXPERIMENT_ITEM, EnumSet.copyOf(ImmutableSet.builder() + .add(ExperimentsComparisonField.OUTPUT) + .build()), + FilterStrategy.DATASET_ITEM, EnumSet.copyOf(ImmutableSet.builder() + .add(ExperimentsComparisonField.INPUT) + .add(ExperimentsComparisonField.EXPECTED_OUTPUT) + .add(ExperimentsComparisonField.METADATA) + .build()))); private static final Set KEY_SUPPORTED_FIELDS_SET = EnumSet.of( FieldType.DICTIONARY, @@ -182,12 +202,15 @@ private String toAnalyticsDbFilter(Filter filter, int i) { } private String getAnalyticsDbField(Field field) { - if 
(field instanceof TraceField) { - return TRACE_FIELDS_MAP.get(field); - } else if (field instanceof SpanField) { - return SPAN_FIELDS_MAP.get(field); - } - throw new IllegalArgumentException("Unknown type for field '%s', type '%s'".formatted(field, field.getClass())); + + return switch (field) { + case TraceField traceField -> TRACE_FIELDS_MAP.get(traceField); + case SpanField spanField -> SPAN_FIELDS_MAP.get(spanField); + case ExperimentsComparisonField experimentsComparisonField -> + EXPERIMENTS_COMPARISON_FIELDS_MAP.get(experimentsComparisonField); + default -> throw new IllegalArgumentException( + "Unknown type for field '%s', type '%s'".formatted(field, field.getClass())); + }; } public Statement bind( diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/filter/FilterStrategy.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/filter/FilterStrategy.java index 51ca9c00f5..49b23f64bb 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/filter/FilterStrategy.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/filter/FilterStrategy.java @@ -4,5 +4,7 @@ public enum FilterStrategy { TRACE, TRACE_AGGREGATION, SPAN, + EXPERIMENT_ITEM, + DATASET_ITEM, FEEDBACK_SCORES } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/internal/UsageResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/internal/UsageResourceTest.java index 807ecab9f6..82419bec20 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/internal/UsageResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/internal/UsageResourceTest.java @@ -55,7 +55,8 @@ public class UsageResourceTest { private static final MySQLContainer MYSQL_CONTAINER = MySQLContainerUtils.newMySQLContainer(); - private static final ClickHouseContainer CLICK_HOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer(false); + private static final ClickHouseContainer 
CLICK_HOUSE_CONTAINER = ClickHouseContainerUtils + .newClickHouseContainer(false); @RegisterExtension private static final TestDropwizardAppExtension app; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceIntegrationTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceIntegrationTest.java index 4ae4e52348..7da4ae740f 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceIntegrationTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceIntegrationTest.java @@ -2,9 +2,11 @@ import com.comet.opik.api.DatasetItem; import com.comet.opik.api.DatasetItemStreamRequest; +import com.comet.opik.api.filter.FiltersFactory; import com.comet.opik.domain.DatasetItemService; import com.comet.opik.domain.DatasetService; import com.comet.opik.domain.Streamer; +import com.comet.opik.domain.filter.FilterQueryBuilder; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.json.JsonNodeMessageBodyWriter; import com.comet.opik.podam.PodamFactoryUtils; @@ -43,7 +45,8 @@ class DatasetsResourceIntegrationTest { private static final ResourceExtension EXT = ResourceExtension.builder() .addResource(new DatasetsResource( - service, itemService, () -> requestContext, timeBasedGenerator::generate, new Streamer())) + service, itemService, () -> requestContext, new FiltersFactory(new FilterQueryBuilder()), + timeBasedGenerator::generate, new Streamer())) .addProvider(JsonNodeMessageBodyWriter.class) .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) .build(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java index 4707fc4304..e1b259fb1b 100644 --- 
a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java @@ -14,9 +14,14 @@ import com.comet.opik.api.FeedbackScoreBatch; import com.comet.opik.api.FeedbackScoreBatchItem; import com.comet.opik.api.Project; +import com.comet.opik.api.ScoreSource; import com.comet.opik.api.Span; import com.comet.opik.api.Trace; import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.api.filter.ExperimentsComparisonField; +import com.comet.opik.api.filter.ExperimentsComparisonFilter; +import com.comet.opik.api.filter.Filter; +import com.comet.opik.api.filter.Operator; import com.comet.opik.api.resources.utils.AuthTestUtils; import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; import com.comet.opik.api.resources.utils.ClientSupportUtils; @@ -40,6 +45,8 @@ import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.RandomStringUtils; import org.glassfish.jersey.client.ChunkedInput; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.AfterAll; @@ -62,10 +69,14 @@ import uk.co.jemos.podam.api.PodamFactory; import java.math.BigDecimal; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.UUID; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -3240,6 +3251,332 @@ void findInvalidExperimentIds(String experimentIds) { assertThat(actualErrorMessage).isEqualTo(expectedErrorMessage); } } + + @ParameterizedTest + @MethodSource + void find__whenFilteringBySupportedFields__thenReturnMatchingRows(Filter filter) { + var workspaceName = 
UUID.randomUUID().toString(); + var apiKey = UUID.randomUUID().toString(); + var workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var dataset = factory.manufacturePojo(Dataset.class); + var datasetId = createAndAssert(dataset, apiKey, workspaceName); + + List items = new ArrayList<>(); + createDatasetItems(items); + var batch = DatasetItemBatch.builder() + .items(items) + .datasetId(datasetId) + .build(); + putAndAssert(batch, workspaceName, apiKey); + + String projectName = RandomStringUtils.randomAlphanumeric(20); + List traces = new ArrayList<>(); + createTraces(items, projectName, workspaceName, apiKey, traces); + + UUID experimentId = GENERATOR.generate(); + + List scores = new ArrayList<>(); + createScores(traces, projectName, scores); + createScoreAndAssert(new FeedbackScoreBatch(scores), apiKey, workspaceName); + + List experimentItems = new ArrayList<>(); + createExperimentItems(items, traces, scores, experimentId, experimentItems); + + createAndAssert( + ExperimentItemsBatch.builder() + .experimentItems(Set.copyOf(experimentItems)) + .build(), + apiKey, + workspaceName); + + List filters = List.of(filter); + + try (var actualResponse = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .path(datasetId.toString()) + .path(DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_PATH) + .queryParam("experiment_ids", JsonUtils.writeValueAsString(List.of(experimentId))) + .queryParam("filters", toURLEncodedQueryParam(filters)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + assertThat(actualResponse.hasEntity()).isTrue(); + + var actualPage = actualResponse.readEntity(DatasetItemPage.class); + + assertThat(actualPage.size()).isEqualTo(1); + assertThat(actualPage.total()).isEqualTo(1); + assertThat(actualPage.page()).isEqualTo(1); + 
assertThat(actualPage.content()).hasSize(1); + assertDatasetItemPage(actualPage, items, experimentItems); + } + } + + Stream find__whenFilteringBySupportedFields__thenReturnMatchingRows() { + return Stream.of( + arguments(new ExperimentsComparisonFilter(ExperimentsComparisonField.FEEDBACK_SCORES, + Operator.EQUAL, "sql_cost", "10")), + arguments(new ExperimentsComparisonFilter(ExperimentsComparisonField.INPUT, Operator.CONTAINS, null, + "sql_cost")), + arguments(new ExperimentsComparisonFilter(ExperimentsComparisonField.OUTPUT, Operator.CONTAINS, + null, "sql_cost")), + arguments(new ExperimentsComparisonFilter(ExperimentsComparisonField.EXPECTED_OUTPUT, + Operator.CONTAINS, null, "sql_cost")), + arguments(new ExperimentsComparisonFilter(ExperimentsComparisonField.METADATA, Operator.EQUAL, + "sql_cost", "10"))); + } + + private void createExperimentItems(List items, List traces, + List scores, UUID experimentId, List experimentItems) { + for (int i = 0; i < items.size(); i++) { + var item = items.get(i); + var trace = traces.get(i); + var score = scores.get(i); + + var experimentItem = ExperimentItem.builder() + .id(GENERATOR.generate()) + .datasetItemId(item.id()) + .traceId(trace.id()) + .experimentId(experimentId) + .input(trace.input()) + .output(trace.output()) + .feedbackScores(Stream.of(score) + .map(FeedbackScoreMapper.INSTANCE::toFeedbackScore) + .toList()) + .build(); + + experimentItems.add(experimentItem); + } + } + + private void createScores(List traces, String projectName, List scores) { + for (int i = 0; i < traces.size(); i++) { + var trace = traces.get(i); + + var score = FeedbackScoreBatchItem.builder() + .name("sql_cost") + .value(BigDecimal.valueOf(i == 0 ? 
10 : i)) + .source(ScoreSource.SDK) + .id(trace.id()) + .projectName(projectName) + .build(); + + scores.add(score); + } + } + + private void createTraces(List items, String projectName, String workspaceName, String apiKey, + List traces) { + for (int i = 0; i < items.size(); i++) { + var item = items.get(i); + var trace = Trace.builder() + .id(GENERATOR.generate()) + .input(item.input()) + .output(item.expectedOutput()) + .projectName(projectName) + .startTime(Instant.now()) + .name("trace-" + i) + .build(); + + createAndAssert(trace, workspaceName, apiKey); + traces.add(trace); + } + } + + private void createDatasetItems(List items) { + for (int i = 0; i < 5; i++) { + if (i == 0) { + DatasetItem item = factory.manufacturePojo(DatasetItem.class) + .toBuilder() + .input(JsonUtils + .getJsonNodeFromString(JsonUtils.writeValueAsString(Map.of("input", "sql_cost")))) + .expectedOutput(JsonUtils + .getJsonNodeFromString(JsonUtils.writeValueAsString(Map.of("output", "sql_cost")))) + .metadata(JsonUtils + .getJsonNodeFromString(JsonUtils.writeValueAsString(Map.of("sql_cost", 10)))) + .source(DatasetItemSource.SDK) + .traceId(null) + .spanId(null) + .build(); + + items.add(item); + } else { + var item = factory.manufacturePojo(DatasetItem.class); + + items.add(item); + } + } + } + + private void createScoreAndAssert(FeedbackScoreBatch feedbackScoreBatch, String apiKey, String workspaceName) { + try (var actualResponse = client.target(getTracesPath()) + .path("feedback-scores") + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .put(Entity.json(feedbackScoreBatch))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204); + assertThat(actualResponse.hasEntity()).isFalse(); + } + } + + private String toURLEncodedQueryParam(List filters) { + return CollectionUtils.isEmpty(filters) + ? 
null + : URLEncoder.encode(JsonUtils.writeValueAsString(filters), StandardCharsets.UTF_8); + } + + @ParameterizedTest + @MethodSource + void find__whenFilterInvalidOperatorForFieldType__thenReturn400(ExperimentsComparisonFilter filter) { + var workspaceName = UUID.randomUUID().toString(); + var apiKey = UUID.randomUUID().toString(); + var workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var expectedError = new io.dropwizard.jersey.errors.ErrorMessage( + 400, + "Invalid operator '%s' for field '%s' of type '%s'".formatted( + filter.operator().getQueryParamOperator(), + filter.field().getQueryParamField(), + filter.field().getType())); + + var datasetId = GENERATOR.generate(); + var experimentId = GENERATOR.generate(); + var filters = List.of(filter); + + try (var actualResponse = client.target(BASE_RESOURCE_URI.formatted(baseURI)) + .path(datasetId.toString()) + .path(DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_PATH) + .queryParam("experiment_ids", JsonUtils.writeValueAsString(List.of(experimentId))) + .queryParam("filters", toURLEncodedQueryParam(filters)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(400); + assertThat(actualResponse.hasEntity()).isTrue(); + + var actualError = actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class); + assertThat(actualError).isEqualTo(expectedError); + } + + } + + static Stream find__whenFilterInvalidOperatorForFieldType__thenReturn400() { + return Stream.of( + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.FEEDBACK_SCORES) + .operator(Operator.CONTAINS) + .value(RandomStringUtils.randomAlphanumeric(10)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.FEEDBACK_SCORES) + .operator(Operator.NOT_CONTAINS) + 
.value(RandomStringUtils.randomAlphanumeric(10)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.FEEDBACK_SCORES) + .operator(Operator.STARTS_WITH) + .value(RandomStringUtils.randomAlphanumeric(10)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.FEEDBACK_SCORES) + .operator(Operator.ENDS_WITH) + .value(RandomStringUtils.randomAlphanumeric(10)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.INPUT) + .operator(Operator.GREATER_THAN) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.INPUT) + .operator(Operator.LESS_THAN) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.INPUT) + .operator(Operator.GREATER_THAN_EQUAL) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.INPUT) + .operator(Operator.LESS_THAN_EQUAL) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.OUTPUT) + .operator(Operator.GREATER_THAN) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.OUTPUT) + .operator(Operator.LESS_THAN) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.OUTPUT) + .operator(Operator.GREATER_THAN_EQUAL) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.OUTPUT) + .operator(Operator.LESS_THAN_EQUAL) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + 
Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.EXPECTED_OUTPUT) + .operator(Operator.GREATER_THAN) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.EXPECTED_OUTPUT) + .operator(Operator.LESS_THAN) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.EXPECTED_OUTPUT) + .operator(Operator.GREATER_THAN_EQUAL) + .value(RandomStringUtils.randomNumeric(3)) + .build()), + Arguments.of(ExperimentsComparisonFilter.builder() + .field(ExperimentsComparisonField.EXPECTED_OUTPUT) + .operator(Operator.LESS_THAN_EQUAL) + .value(RandomStringUtils.randomNumeric(3)) + .build())); + } + } + + private void assertDatasetItemPage(DatasetItemPage actualPage, List items, + List experimentItems) { + assertThat(actualPage.content().getFirst()) + .usingRecursiveComparison() + .ignoringFields(IGNORED_FIELDS_DATA_ITEM) + .isEqualTo(items.getFirst()); + + var actualExperimentItems = actualPage.content().getFirst().experimentItems(); + assertThat(actualExperimentItems).hasSize(1); + assertThat(actualExperimentItems.getFirst()) + .usingRecursiveComparison() + .ignoringFields(IGNORED_FIELDS_LIST) + .isEqualTo(experimentItems.getFirst()); + + var actualFeedbackScores = actualExperimentItems.getFirst().feedbackScores(); + assertThat(actualFeedbackScores).hasSize(1); + + assertThat(actualFeedbackScores.getFirst()) + .usingRecursiveComparison() + .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .isEqualTo(experimentItems.getFirst().feedbackScores().getFirst()); } private void putAndAssert(DatasetItemBatch batch, String workspaceName, String apiKey) {