[OPIK-147] Add filters to dataset item experiment compare endpoint (#359)

* [OPIK-147] Add filters to dataset item experiment compare endpoints

* Add invalid filters tests
thiagohora authored Oct 9, 2024
1 parent 63d162b commit 373194c
Showing 11 changed files with 578 additions and 54 deletions.
com/comet/opik/api/filter/ExperimentsComparisonField.java
@@ -0,0 +1,19 @@
package com.comet.opik.api.filter;

import lombok.Getter;
import lombok.RequiredArgsConstructor;

@RequiredArgsConstructor
@Getter
public enum ExperimentsComparisonField implements Field {

INPUT(INPUT_QUERY_PARAM, FieldType.STRING),
EXPECTED_OUTPUT(EXPECTED_OUTPUT_QUERY_PARAM, FieldType.STRING),
OUTPUT(OUTPUT_QUERY_PARAM, FieldType.STRING),
METADATA(METADATA_QUERY_PARAM, FieldType.DICTIONARY),
FEEDBACK_SCORES(FEEDBACK_SCORES_QUERY_PARAM, FieldType.FEEDBACK_SCORES_NUMBER),
;

private final String queryParamField;
private final FieldType type;
}
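The enum ties each filterable field to the name clients use in the filters query parameter ("input", "expected_output", "output", "metadata", "feedback_scores") and to a FieldType that presumably drives operator/type validation downstream. As a quick illustration, a hypothetical helper (not part of this commit) could resolve a constant from its query-param name via the Lombok-generated getter:

import com.comet.opik.api.filter.ExperimentsComparisonField;

import java.util.Arrays;
import java.util.Optional;

class FieldLookupSketch {

    // Hypothetical lookup: find the constant whose queryParamField matches the
    // name a client sent, e.g. "expected_output" -> EXPECTED_OUTPUT.
    static Optional<ExperimentsComparisonField> fromQueryParam(String name) {
        return Arrays.stream(ExperimentsComparisonField.values())
                .filter(field -> field.getQueryParamField().equals(name))
                .findFirst();
    }
}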
com/comet/opik/api/filter/ExperimentsComparisonFilter.java
@@ -0,0 +1,28 @@
package com.comet.opik.api.filter;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import lombok.experimental.SuperBuilder;

import java.util.List;

@SuperBuilder(toBuilder = true)
public class ExperimentsComparisonFilter extends FilterImpl {

public static final TypeReference<List<ExperimentsComparisonFilter>> LIST_TYPE_REFERENCE = new TypeReference<>() {
};

@JsonCreator
public ExperimentsComparisonFilter(@JsonProperty(value = "field", required = true) ExperimentsComparisonField field,
@JsonProperty(value = "operator", required = true) Operator operator,
@JsonProperty("key") String key,
@JsonProperty(value = "value", required = true) String value) {
super(field, operator, key, value);
}

@Override
public Filter build(String decodedValue) {
return toBuilder().value(decodedValue).build();
}
}
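LIST_TYPE_REFERENCE exists so the raw filters string can survive Java's type erasure and deserialize into a List<ExperimentsComparisonFilter>. A minimal sketch of that round trip with a plain Jackson ObjectMapper — the payload is illustrative, and the exact accepted spellings of field and operator depend on the enums' Jackson setup (the Operator enum is not part of this diff):

import com.comet.opik.api.filter.ExperimentsComparisonFilter;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.util.List;

class FilterParsingSketch {

    private static final ObjectMapper MAPPER = new ObjectMapper();

    public static void main(String[] args) throws Exception {
        // Shape follows the @JsonCreator constructor: field, operator, optional key, value.
        var json = """
                [
                  {"field": "input", "operator": "contains", "value": "hello"},
                  {"field": "feedback_scores", "operator": ">", "key": "accuracy", "value": "0.5"}
                ]
                """;
        List<ExperimentsComparisonFilter> filters = MAPPER.readValue(json,
                ExperimentsComparisonFilter.LIST_TYPE_REFERENCE);
        System.out.println(filters.size()); // 2
    }
}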
com/comet/opik/api/filter/Field.java
@@ -11,6 +11,7 @@ public interface Field {
String INPUT_QUERY_PARAM = "input";
String OUTPUT_QUERY_PARAM = "output";
String METADATA_QUERY_PARAM = "metadata";
String EXPECTED_OUTPUT_QUERY_PARAM = "expected_output";
String TAGS_QUERY_PARAM = "tags";
String USAGE_COMPLETION_TOKENS_QUERY_PARAM = "usage.completion_tokens";
String USAGE_PROMPT_TOKENS_QUERY_PARAM = "usage.prompt_tokens";
DatasetsResource.java
@@ -11,6 +11,8 @@
import com.comet.opik.api.DatasetItemsDelete;
import com.comet.opik.api.DatasetUpdate;
import com.comet.opik.api.ExperimentItem;
import com.comet.opik.api.filter.ExperimentsComparisonFilter;
import com.comet.opik.api.filter.FiltersFactory;
import com.comet.opik.domain.DatasetItemService;
import com.comet.opik.domain.DatasetService;
import com.comet.opik.domain.FeedbackScoreDAO;
@@ -82,6 +84,7 @@ public class DatasetsResource {
private final @NonNull DatasetService service;
private final @NonNull DatasetItemService itemService;
private final @NonNull Provider<RequestContext> requestContext;
private final @NonNull FiltersFactory filtersFactory;
private final @NonNull IdGenerator idGenerator;
private final @NonNull Streamer streamer;

@@ -348,9 +351,12 @@ public Response findDatasetItemsWithExperimentItems(

var experimentIds = getExperimentIds(experimentIdsQueryParam);

var queryFilters = filtersFactory.newFilters(filters, ExperimentsComparisonFilter.LIST_TYPE_REFERENCE);

var datasetItemSearchCriteria = DatasetItemSearchCriteria.builder()
.datasetId(datasetId)
.experimentIds(experimentIds)
.filters(queryFilters)
.entityType(FeedbackScoreDAO.EntityType.TRACE)
.build();

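With this wiring, the raw filters string is parsed and validated by FiltersFactory before the search criteria are built; the companion commit message ("Add invalid filters tests") suggests that malformed filters — unknown fields, or operators that do not match a field's type — are rejected there, presumably with a 400 response.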
com/comet/opik/domain/AsyncContextUtils.java
@@ -1,28 +1,29 @@
package com.comet.opik.domain;

import com.comet.opik.utils.AsyncUtils;
import com.comet.opik.utils.AsyncUtils.ContextAwareAction;
import com.comet.opik.utils.AsyncUtils.ContextAwareStream;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Statement;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

class AsyncContextUtils {

-    static AsyncUtils.ContextAwareStream<? extends Result> bindWorkspaceIdToFlux(Statement statement) {
+    static ContextAwareStream<Result> bindWorkspaceIdToFlux(Statement statement) {
return (userName, workspaceName, workspaceId) -> {
statement.bind("workspace_id", workspaceId);
return Flux.from(statement.execute());
};
}

-    static AsyncUtils.ContextAwareAction<? extends Result> bindWorkspaceIdToMono(Statement statement) {
+    static ContextAwareAction<Result> bindWorkspaceIdToMono(Statement statement) {
return (userName, workspaceName, workspaceId) -> {
statement.bind("workspace_id", workspaceId);
return Mono.from(statement.execute());
};
}

-    static AsyncUtils.ContextAwareAction<? extends Result> bindUserNameAndWorkspaceContext(Statement statement) {
+    static ContextAwareAction<Result> bindUserNameAndWorkspaceContext(Statement statement) {
return (userName, workspaceName, workspaceId) -> {
statement.bind("user_name", userName);
statement.bind("workspace_id", workspaceId);
@@ -31,7 +32,7 @@ static AsyncUtils.ContextAwareAction<? extends Result> bindUserNameAndWorkspaceC
};
}

-    static AsyncUtils.ContextAwareStream<? extends Result> bindUserNameAndWorkspaceContextToStream(
+    static ContextAwareStream<Result> bindUserNameAndWorkspaceContextToStream(
Statement statement) {
return (userName, workspaceName, workspaceId) -> {
statement.bind("user_name", userName);
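A side note on this file: replacing the `? extends Result` wildcards with concrete ContextAwareStream<Result> / ContextAwareAction<Result> return types (plus static imports of the nested types) presumably removes wildcard-capture friction for callers — the DAO below can now pass bindWorkspaceIdToFlux(statement) straight into makeFluxContextAware.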
com/comet/opik/domain/DatasetItemDAO.java
@@ -6,6 +6,8 @@
import com.comet.opik.api.ExperimentItem;
import com.comet.opik.api.FeedbackScore;
import com.comet.opik.api.ScoreSource;
import com.comet.opik.domain.filter.FilterQueryBuilder;
import com.comet.opik.domain.filter.FilterStrategy;
import com.comet.opik.infrastructure.db.TransactionTemplateAsync;
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.databind.JsonNode;
@@ -178,20 +180,51 @@ INSERT INTO dataset_items (
FROM dataset_items
WHERE dataset_id = :datasetId
AND workspace_id = :workspace_id
<if(dataset_item_filters)>
AND <dataset_item_filters>
<endif>
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
) AS di
INNER JOIN (
SELECT
-            dataset_item_id
-        FROM experiment_items
+            dataset_item_id,
+            trace_id
+        FROM experiment_items ei
<if(experiment_item_filters || feedback_scores_filters)>
INNER JOIN (
SELECT
id
FROM traces
WHERE workspace_id = :workspace_id
<if(experiment_item_filters)>
AND <experiment_item_filters>
<endif>
<if(feedback_scores_filters)>
AND id in (
SELECT
entity_id
FROM (
SELECT *
FROM feedback_scores
WHERE entity_type = 'trace'
AND workspace_id = :workspace_id
ORDER BY entity_id DESC, last_updated_at DESC
LIMIT 1 BY entity_id, name
)
GROUP BY entity_id
HAVING <feedback_scores_filters>
)
<endif>
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
) AS tfs ON ei.trace_id = tfs.id
<endif>
WHERE experiment_id in :experimentIds
AND workspace_id = :workspace_id
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
) AS ei ON di.id = ei.dataset_item_id
GROUP BY
di.id
;
""";

@@ -237,6 +270,7 @@ INNER JOIN (
private static final String SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS = """
SELECT
di.id AS id,
di.dataset_id AS dataset_id,
di.input AS input,
di.expected_output AS expected_output,
di.metadata AS metadata,
@@ -266,13 +300,45 @@ INNER JOIN (
FROM dataset_items
WHERE dataset_id = :datasetId
AND workspace_id = :workspace_id
<if(dataset_item_filters)>
AND <dataset_item_filters>
<endif>
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
) AS di
INNER JOIN (
SELECT
-            *
-        FROM experiment_items
+            DISTINCT ei.*
+        FROM experiment_items ei
<if(experiment_item_filters || feedback_scores_filters)>
INNER JOIN (
SELECT
id
FROM traces
WHERE workspace_id = :workspace_id
<if(experiment_item_filters)>
AND <experiment_item_filters>
<endif>
<if(feedback_scores_filters)>
AND id in (
SELECT
entity_id
FROM (
SELECT *
FROM feedback_scores
WHERE entity_type = 'trace'
AND workspace_id = :workspace_id
ORDER BY entity_id DESC, last_updated_at DESC
LIMIT 1 BY entity_id, name
)
GROUP BY entity_id
HAVING <feedback_scores_filters>
)
<endif>
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
) AS tfs ON ei.trace_id = tfs.id
<endif>
WHERE experiment_id in :experimentIds
AND workspace_id = :workspace_id
ORDER BY id DESC, last_updated_at DESC
@@ -322,6 +388,7 @@ LEFT JOIN (
) AS tfs ON ei.trace_id = tfs.id
GROUP BY
di.id,
di.dataset_id,
di.input,
di.expected_output,
di.metadata,
@@ -338,6 +405,7 @@
""";

private final @NonNull TransactionTemplateAsync asyncTemplate;
private final @NonNull FilterQueryBuilder filterQueryBuilder;

@Override
@Trace(dispatcher = true)
@@ -465,12 +533,14 @@ private List<FeedbackScore> getFeedbackScores(Object feedbackScoresRaw) {
if (feedbackScoresRaw instanceof List[] feedbackScoresArray) {
var feedbackScores = Arrays.stream(feedbackScoresArray)
.filter(feedbackScore -> CollectionUtils.isNotEmpty(feedbackScore) &&
-                        !CLICKHOUSE_FIXED_STRING_UUID_FIELD_NULL_VALUE.equals(feedbackScore.get(0).toString()))
+                        !CLICKHOUSE_FIXED_STRING_UUID_FIELD_NULL_VALUE.equals(feedbackScore.getFirst().toString()))
                 .map(feedbackScore -> FeedbackScore.builder()
                         .name(feedbackScore.get(1).toString())
-                        .categoryName(Optional.ofNullable(feedbackScore.get(2)).map(Object::toString).orElse(null))
+                        .categoryName(Optional.ofNullable(feedbackScore.get(2)).map(Object::toString)
+                                .filter(StringUtils::isNotEmpty).orElse(null))
                         .value(new BigDecimal(feedbackScore.get(3).toString()))
-                        .reason(Optional.ofNullable(feedbackScore.get(4)).map(Object::toString).orElse(null))
+                        .reason(Optional.ofNullable(feedbackScore.get(4)).map(Object::toString)
+                                .filter(StringUtils::isNotEmpty).orElse(null))
.source(ScoreSource.fromString(feedbackScore.get(5).toString()))
.build())
.toList();
@@ -606,6 +676,34 @@ public Mono<DatasetItemPage> getItems(@NonNull UUID datasetId, int page, int siz
})));
}

private ST newFindTemplate(String query, DatasetItemSearchCriteria datasetItemSearchCriteria) {
var template = new ST(query);

Optional.ofNullable(datasetItemSearchCriteria.filters())
.ifPresent(filters -> {
filterQueryBuilder.toAnalyticsDbFilters(filters, FilterStrategy.DATASET_ITEM)
.ifPresent(datasetItemFilters -> template.add("dataset_item_filters", datasetItemFilters));

filterQueryBuilder.toAnalyticsDbFilters(filters, FilterStrategy.EXPERIMENT_ITEM)
.ifPresent(experimentItemFilters -> template.add("experiment_item_filters",
experimentItemFilters));

filterQueryBuilder.toAnalyticsDbFilters(filters, FilterStrategy.FEEDBACK_SCORES)
.ifPresent(scoresFilters -> template.add("feedback_scores_filters", scoresFilters));
});

return template;
}

private void bindSearchCriteria(DatasetItemSearchCriteria datasetItemSearchCriteria, Statement statement) {
Optional.ofNullable(datasetItemSearchCriteria.filters())
.ifPresent(filters -> {
filterQueryBuilder.bind(statement, filters, FilterStrategy.DATASET_ITEM);
filterQueryBuilder.bind(statement, filters, FilterStrategy.EXPERIMENT_ITEM);
filterQueryBuilder.bind(statement, filters, FilterStrategy.FEEDBACK_SCORES);
});
}
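newFindTemplate and bindSearchCriteria deliberately make two passes over the same criteria: the first renders filter fragments into the SQL text, the second binds the matching positional parameters onto the resulting statement. The two passes must stay in sync — a fragment rendered without its parameters bound (or vice versa) would fail at execution time — which is presumably why both iterate the same three FilterStrategy values (DATASET_ITEM, EXPERIMENT_ITEM, FEEDBACK_SCORES).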

@Override
@Trace(dispatcher = true)
public Mono<DatasetItemPage> getItems(
@@ -615,37 +713,42 @@ public Mono<DatasetItemPage> getItems(

Segment segmentCount = startSegment("dataset_items", "Clickhouse", "select_dataset_items_filters_count");

-        return makeMonoContextAware((userName, workspaceName, workspaceId) -> asyncTemplate.nonTransaction(
-                connection -> Flux
-                        .from(connection
-                                .createStatement(
-                                        SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_COUNT)
-                                .bind("datasetId", datasetItemSearchCriteria.datasetId())
-                                .bind("experimentIds", datasetItemSearchCriteria.experimentIds())
-                                .bind("workspace_id", workspaceId)
-                                .execute())
-                        .doFinally(signalType -> segmentCount.end())
-                        .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, Long.class)))
-                        .reduce(0L, Long::sum)
-                        .flatMap(count -> {
-                            Segment segment = startSegment("dataset_items", "Clickhouse",
-                                    "select_dataset_items_filters");
-
-                            return Flux
-                                    .from(connection
-                                            .createStatement(
-                                                    SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS)
-                                            .bind("datasetId", datasetItemSearchCriteria.datasetId())
-                                            .bind("experimentIds", datasetItemSearchCriteria.experimentIds())
-                                            .bind("entityType", datasetItemSearchCriteria.entityType().getType())
-                                            .bind("workspace_id", workspaceId)
-                                            .bind("limit", size)
-                                            .bind("offset", (page - 1) * size)
-                                            .execute())
-                                    .doFinally(signalType -> segment.end())
-                                    .flatMap(this::mapItem)
-                                    .collectList()
-                                    .flatMap(items -> Mono.just(new DatasetItemPage(items, page, items.size(), count)));
-                        })));
+        return asyncTemplate.nonTransaction(connection -> {
+
+            ST countTemplate = newFindTemplate(SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS_COUNT,
+                    datasetItemSearchCriteria);
+
+            var statement = connection.createStatement(countTemplate.render())
+                    .bind("datasetId", datasetItemSearchCriteria.datasetId())
+                    .bind("experimentIds", datasetItemSearchCriteria.experimentIds().toArray(UUID[]::new));
+
+            bindSearchCriteria(datasetItemSearchCriteria, statement);
+
+            return makeFluxContextAware(bindWorkspaceIdToFlux(statement))
+                    .doFinally(signalType -> segmentCount.end())
+                    .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, Long.class)))
+                    .reduce(0L, Long::sum)
+                    .flatMap(count -> {
+                        Segment segment = startSegment("dataset_items", "Clickhouse", "select_dataset_items_filters");
+
+                        ST selectTemplate = newFindTemplate(SELECT_DATASET_ITEMS_WITH_EXPERIMENT_ITEMS,
+                                datasetItemSearchCriteria);
+
+                        var selectStatement = connection.createStatement(selectTemplate.render())
+                                .bind("datasetId", datasetItemSearchCriteria.datasetId())
+                                .bind("experimentIds", datasetItemSearchCriteria.experimentIds().toArray(UUID[]::new))
+                                .bind("entityType", datasetItemSearchCriteria.entityType().getType())
+                                .bind("limit", size)
+                                .bind("offset", (page - 1) * size);
+
+                        bindSearchCriteria(datasetItemSearchCriteria, selectStatement);
+
+                        return makeFluxContextAware(bindWorkspaceIdToFlux(selectStatement))
+                                .doFinally(signalType -> segment.end())
+                                .flatMap(this::mapItem)
+                                .collectList()
+                                .flatMap(items -> Mono.just(new DatasetItemPage(items, page, items.size(), count)));
+                    });
+        });
}
}