Skip to content

Commit

Permalink
OPIK-645 Return all feedback score names for project ids endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
Borys Tkachenko authored and Borys Tkachenko committed Dec 20, 2024
1 parent 7cae567 commit ea0ad3e
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import com.comet.opik.api.PageColumns;
import com.comet.opik.api.filter.ExperimentsComparisonFilter;
import com.comet.opik.api.filter.FiltersFactory;
import com.comet.opik.api.resources.v1.priv.validate.ExperimentParamsValidator;
import com.comet.opik.api.resources.v1.priv.validate.IdParamsValidator;
import com.comet.opik.api.sorting.SortingFactoryDatasets;
import com.comet.opik.api.sorting.SortingField;
import com.comet.opik.domain.DatasetItemService;
Expand Down Expand Up @@ -375,7 +375,7 @@ public Response findDatasetItemsWithExperimentItems(
@QueryParam("filters") String filters,
@QueryParam("truncate") boolean truncate) {

var experimentIds = ExperimentParamsValidator.getExperimentIds(experimentIdsQueryParam);
var experimentIds = IdParamsValidator.getIds(experimentIdsQueryParam);

var queryFilters = filtersFactory.newFilters(filters, ExperimentsComparisonFilter.LIST_TYPE_REFERENCE);

Expand Down Expand Up @@ -413,7 +413,7 @@ public Response getDatasetItemsOutputColumns(

var experimentIds = Optional.ofNullable(experimentIdsQueryParam)
.filter(Predicate.not(String::isEmpty))
.map(ExperimentParamsValidator::getExperimentIds)
.map(IdParamsValidator::getIds)
.orElse(null);

String workspaceId = requestContext.get().getWorkspaceId();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import com.comet.opik.api.FeedbackDefinition;
import com.comet.opik.api.FeedbackScoreNames;
import com.comet.opik.api.Identifier;
import com.comet.opik.api.resources.v1.priv.validate.ExperimentParamsValidator;
import com.comet.opik.api.resources.v1.priv.validate.IdParamsValidator;
import com.comet.opik.domain.ExperimentItemService;
import com.comet.opik.domain.ExperimentService;
import com.comet.opik.domain.FeedbackScoreService;
Expand Down Expand Up @@ -281,7 +281,7 @@ public Response deleteExperimentItems(
public Response findFeedbackScoreNames(@QueryParam("experiment_ids") String experimentIdsQueryParam) {

var experimentIds = Optional.ofNullable(experimentIdsQueryParam)
.map(ExperimentParamsValidator::getExperimentIds)
.map(IdParamsValidator::getIds)
.map(List::copyOf)
.orElse(null);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.codahale.metrics.annotation.Timed;
import com.comet.opik.api.BatchDelete;
import com.comet.opik.api.FeedbackScoreNames;
import com.comet.opik.api.Page;
import com.comet.opik.api.Project;
import com.comet.opik.api.ProjectCriteria;
Expand All @@ -10,8 +11,10 @@
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.api.metrics.ProjectMetricRequest;
import com.comet.opik.api.metrics.ProjectMetricResponse;
import com.comet.opik.api.resources.v1.priv.validate.IdParamsValidator;
import com.comet.opik.api.sorting.SortingFactoryProjects;
import com.comet.opik.api.sorting.SortingField;
import com.comet.opik.domain.FeedbackScoreService;
import com.comet.opik.domain.ProjectMetricsService;
import com.comet.opik.domain.ProjectService;
import com.comet.opik.infrastructure.auth.RequestContext;
Expand Down Expand Up @@ -49,6 +52,7 @@
import lombok.extern.slf4j.Slf4j;

import java.util.List;
import java.util.Optional;
import java.util.UUID;

import static com.comet.opik.domain.ProjectMetricsService.ERR_START_BEFORE_END;
Expand All @@ -67,6 +71,7 @@ public class ProjectsResource {
private final @NonNull Provider<RequestContext> requestContext;
private final @NonNull SortingFactoryProjects sortingFactory;
private final @NonNull ProjectMetricsService metricsService;
private final @NonNull FeedbackScoreService feedbackScoreService;

@GET
@Operation(operationId = "findProjects", summary = "Find projects", description = "Find projects", responses = {
Expand Down Expand Up @@ -232,6 +237,32 @@ public Response getProjectMetrics(
return Response.ok().entity(response).build();
}

@GET
@Path("/feedback-scores/names")
@Operation(operationId = "findFeedbackScoreNamesByProjectIds", summary = "Find Feedback Score names By Project Ids", description = "Find Feedback Score names By Project Ids", responses = {
@ApiResponse(responseCode = "200", description = "Feedback Scores resource", content = @Content(schema = @Schema(implementation = FeedbackScoreNames.class)))
})
public Response findFeedbackScoreNames(@QueryParam("project_ids") String projectIdsQueryParam) {

var projectIds = Optional.ofNullable(projectIdsQueryParam)
.map(IdParamsValidator::getIds)
.map(List::copyOf)
.orElse(null);

String workspaceId = requestContext.get().getWorkspaceId();

log.info("Find feedback score names by project_ids '{}', on workspaceId '{}'",
projectIds, workspaceId);
FeedbackScoreNames feedbackScoreNames = feedbackScoreService
.getProjectsFeedbackScoreNames(projectIds)
.contextWrite(ctx -> setRequestContext(ctx, requestContext))
.block();
log.info("Found feedback score names '{}' by project_ids '{}', on workspaceId '{}'",
feedbackScoreNames.scores().size(), projectIds, workspaceId);

return Response.ok(feedbackScoreNames).build();
}

private void validate(ProjectMetricRequest request) {
if (!request.intervalStart().isBefore(request.intervalEnd())) {
throw new BadRequestException(ERR_START_BEFORE_END);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@

@UtilityClass
@Slf4j
public class ExperimentParamsValidator {
public class IdParamsValidator {

private static final TypeReference<List<UUID>> LIST_UUID_TYPE_REFERENCE = new TypeReference<>() {
};

public static Set<UUID> getExperimentIds(String experimentIds) {
var message = "Invalid query param experiment ids '%s'".formatted(experimentIds);
public static Set<UUID> getIds(String idsQueryParam) {
var message = "Invalid query param ids '%s'".formatted(idsQueryParam);
try {
return JsonUtils.readValue(experimentIds, LIST_UUID_TYPE_REFERENCE)
return JsonUtils.readValue(idsQueryParam, LIST_UUID_TYPE_REFERENCE)
.stream()
.collect(Collectors.toUnmodifiableSet());
} catch (RuntimeException exception) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ Mono<Long> scoreEntity(EntityType entityType, UUID entityId, FeedbackScore score
Mono<List<String>> getSpanFeedbackScoreNames(@NonNull UUID projectId, SpanType type);

Mono<List<String>> getExperimentsFeedbackScoreNames(List<UUID> experimentIds);

Mono<List<String>> getProjectsFeedbackScoreNames(List<UUID> projectIds);
}

@Singleton
Expand Down Expand Up @@ -203,6 +205,23 @@ INNER JOIN (
;
""";

private static final String SELECT_PROJECTS_FEEDBACK_SCORE_NAMES = """
SELECT
distinct name
FROM (
SELECT
name
FROM feedback_scores
WHERE workspace_id = :workspace_id
<if(project_ids)>
AND project_id IN :project_ids
<endif>
ORDER BY entity_id DESC, last_updated_at DESC
LIMIT 1 BY entity_id, name
) AS names
;
""";

private final static String SELECT_SPAN_FEEDBACK_SCORE_NAMES = """
SELECT
distinct name
Expand Down Expand Up @@ -433,6 +452,34 @@ public Mono<List<String>> getExperimentsFeedbackScoreNames(List<UUID> experiment
});
}

@Override
@WithSpan
public Mono<List<String>> getProjectsFeedbackScoreNames(List<UUID> projectIds) {
return asyncTemplate.nonTransaction(connection -> {

ST template = new ST(SELECT_PROJECTS_FEEDBACK_SCORE_NAMES);

if (CollectionUtils.isNotEmpty(projectIds)) {
template.add("project_ids", projectIds);
}

var statement = connection.createStatement(template.render());

if (CollectionUtils.isNotEmpty(projectIds)) {
template.add("project_ids", projectIds);
}

if (CollectionUtils.isNotEmpty(projectIds)) {
statement.bind("project_ids", projectIds.toArray(UUID[]::new));
}

return makeMonoContextAware(bindWorkspaceIdToMono(statement))
.flatMapMany(result -> result.map((row, rowMetadata) -> row.get("name", String.class)))
.distinct()
.collect(Collectors.toList());
});
}

@Override
@WithSpan
public Mono<List<String>> getSpanFeedbackScoreNames(@NonNull UUID projectId, SpanType type) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ public interface FeedbackScoreService {
Mono<FeedbackScoreNames> getSpanFeedbackScoreNames(UUID projectId, SpanType type);

Mono<FeedbackScoreNames> getExperimentsFeedbackScoreNames(List<UUID> experimentIds);

Mono<FeedbackScoreNames> getProjectsFeedbackScoreNames(List<UUID> projectIds);
}

@Slf4j
Expand Down Expand Up @@ -256,6 +258,13 @@ public Mono<FeedbackScoreNames> getExperimentsFeedbackScoreNames(List<UUID> expe
.map(FeedbackScoreNames::new);
}

@Override
public Mono<FeedbackScoreNames> getProjectsFeedbackScoreNames(List<UUID> projectIds) {
return dao.getProjectsFeedbackScoreNames(projectIds)
.map(names -> names.stream().map(FeedbackScoreNames.ScoreName::new).toList())
.map(FeedbackScoreNames::new);
}

private Mono<Long> failWithNotFound(String errorMessage) {
log.info(errorMessage);
return Mono.error(new NotFoundException(Response.status(404)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package com.comet.opik.api.resources.utils.resources;

import com.comet.opik.api.FeedbackScoreNames;
import com.comet.opik.api.Project;
import com.comet.opik.api.resources.utils.TestUtils;
import com.comet.opik.infrastructure.auth.RequestContext;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.client.WebTarget;
import jakarta.ws.rs.core.HttpHeaders;
import lombok.RequiredArgsConstructor;
import org.apache.hc.core5.http.HttpStatus;
Expand Down Expand Up @@ -74,4 +76,25 @@ public Project getByName(String projectName, String apiKey, String workspaceName
}
}

public FeedbackScoreNames findFeedbackScoreNames(String projectIdsQueryParam, String apiKey, String workspaceName) {
WebTarget webTarget = client.target(RESOURCE_PATH.formatted(baseURI))
.path("feedback-scores")
.path("names");

if (projectIdsQueryParam != null) {
webTarget = webTarget.queryParam("project_ids", projectIdsQueryParam);
}

try (var actualResponse = webTarget
.request()
.header(HttpHeaders.AUTHORIZATION, apiKey)
.header(RequestContext.WORKSPACE_HEADER, workspaceName)
.get()) {

// then
assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK);

return actualResponse.readEntity(FeedbackScoreNames.class);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2659,20 +2659,22 @@ void getFeedbackScoreNames__whenGetFeedbackScoreNames__thenReturnFeedbackScoreNa

createExperimentsItems(apiKey, workspaceName, unexpectedScores, List.of());

fetchAndAssertResponse(userExperimentId, experimentId, names, otherNames, apiKey, workspaceName);
fetchAndAssertResponse(userExperimentId, experimentId, projectId, names, otherNames, apiKey, workspaceName);
}
}

private void fetchAndAssertResponse(boolean userExperimentId, UUID experimentId, List<String> names,
private void fetchAndAssertResponse(boolean userExperimentId, UUID experimentId, UUID projectId, List<String> names,
List<String> otherNames, String apiKey, String workspaceName) {

WebTarget webTarget = client.target(URL_TEMPLATE.formatted(baseURI))
.path("feedback-scores")
.path("names");

String projectIdsQueryParam = null;
if (userExperimentId) {
var ids = JsonUtils.writeValueAsString(List.of(experimentId));
webTarget = webTarget.queryParam("experiment_ids", ids);
projectIdsQueryParam = JsonUtils.writeValueAsString(List.of(projectId));
}

List<String> expectedNames = userExperimentId
Expand All @@ -2688,14 +2690,20 @@ private void fetchAndAssertResponse(boolean userExperimentId, UUID experimentId,
// then
assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_OK);
var actualEntity = actualResponse.readEntity(FeedbackScoreNames.class);

assertThat(actualEntity.scores()).hasSize(expectedNames.size());
assertThat(actualEntity
.scores()
.stream()
.map(FeedbackScoreNames.ScoreName::name)
.toList()).containsExactlyInAnyOrderElementsOf(expectedNames);
assertFeedbackScoreNames(actualEntity, expectedNames);
}

var feedbackScoreNamesByProjectId = projectResourceClient.findFeedbackScoreNames(projectIdsQueryParam, apiKey, workspaceName);
assertFeedbackScoreNames(feedbackScoreNamesByProjectId, expectedNames);
}

private void assertFeedbackScoreNames(FeedbackScoreNames actual, List<String> expectedNames) {
assertThat(actual.scores()).hasSize(expectedNames.size());
assertThat(actual
.scores()
.stream()
.map(FeedbackScoreNames.ScoreName::name)
.toList()).containsExactlyInAnyOrderElementsOf(expectedNames);
}

private List<List<FeedbackScoreBatchItem>> createMultiValueScores(List<String> multipleValuesFeedbackScores,
Expand Down

0 comments on commit ea0ad3e

Please sign in to comment.