diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java index 5e8b13133b..8e39e14a98 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java @@ -11,8 +11,11 @@ import java.time.Instant; import java.util.List; +import java.util.Map; import java.util.UUID; +import static com.comet.opik.api.ProjectStats.PercentageValues; + @Builder(toBuilder = true) @JsonIgnoreProperties(ignoreUnknown = true) // This annotation is used to specify the strategy to be used for naming of properties for the annotated type. Required so that OpenAPI schema generation uses snake_case @@ -29,7 +32,15 @@ public record Project( @JsonView({Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, @JsonView({Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy, @JsonView({ - Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Instant lastUpdatedTraceAt){ + Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Instant lastUpdatedTraceAt, + @JsonView({ + Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable List feedbackScores, + @JsonView({ + Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable PercentageValues duration, + @JsonView({ + Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Double totalEstimatedCost, + @JsonView({ + Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Map usage){ public static class View { public static class Write { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java index dce05d0dc8..6e125ad4ee 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java @@ -20,7 +20,6 @@ @ImplementedBy(ProjectMetricsServiceImpl.class) public interface ProjectMetricsService { String ERR_START_BEFORE_END = "'start_time' must be before 'end_time'"; - String ERR_PROJECT_METRIC_NOT_SUPPORTED = "metric '%s' is not supported"; Mono> getProjectMetrics(UUID projectId, ProjectMetricRequest request); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java index 6e25544810..45ab2d6108 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java @@ -5,6 +5,7 @@ import com.comet.opik.api.Project.ProjectPage; import com.comet.opik.api.ProjectCriteria; import com.comet.opik.api.ProjectIdLastUpdated; +import com.comet.opik.api.ProjectStats; import com.comet.opik.api.ProjectUpdate; import com.comet.opik.api.error.EntityAlreadyExistsException; import com.comet.opik.api.error.ErrorMessage; @@ -13,6 +14,7 @@ import com.comet.opik.api.sorting.SortingFactoryProjects; import com.comet.opik.api.sorting.SortingField; import com.comet.opik.domain.sorting.SortingQueryBuilder; +import com.comet.opik.domain.stats.StatsMapper; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.db.TransactionTemplateAsync; import com.comet.opik.utils.PaginationUtils; @@ -44,6 +46,7 @@ import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY; import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; import static java.util.Collections.reverseOrder; +import static java.util.stream.Collectors.toMap; import static java.util.stream.Collectors.toSet; import static java.util.stream.Collectors.toUnmodifiableSet; @@ -199,8 +202,18 @@ public Project get(@NonNull UUID id, @NonNull String workspaceId) { .nonTransaction(connection -> traceDAO.getLastUpdatedTraceAt(Set.of(id), workspaceId, connection)) .block(); + Map> projectStats = getProjectStats(List.of(id), workspaceId); + + return enhanceProject(project, lastUpdatedTraceAt.get(project.id()), projectStats.get(project.id())); + } + + private Project enhanceProject(Project project, Instant lastUpdatedTraceAt, Map projectStats) { return project.toBuilder() - .lastUpdatedTraceAt(lastUpdatedTraceAt.get(id)) + .lastUpdatedTraceAt(lastUpdatedTraceAt) + .feedbackScores(StatsMapper.getStatsFeedbackScores(projectStats)) + .duration(StatsMapper.getStatsDuration(projectStats)) + .totalEstimatedCost(StatsMapper.getStatsTotalEstimatedCost(projectStats)) + .usage(StatsMapper.getStatsUsage(projectStats)) .build(); } @@ -271,17 +284,37 @@ public Page find(int page, int size, @NonNull ProjectCriteria criteria, return traceDAO.getLastUpdatedTraceAt(projectIds, workspaceId, connection); }).block(); + List projectIds = projectRecordSet.content.stream().map(Project::id).toList(); + + Map> projectStats = getProjectStats(projectIds, workspaceId); + List projects = projectRecordSet.content() .stream() - .map(project -> project.toBuilder() - .lastUpdatedTraceAt(projectLastUpdatedTraceAtMap.get(project.id())) - .build()) + .map(project -> enhanceProject(project, projectLastUpdatedTraceAtMap.get(project.id()), + projectStats.get(project.id()))) .toList(); return new ProjectPage(page, projects.size(), projectRecordSet.total(), projects, sortingFactory.getSortableFields()); } + private Map> getProjectStats(List projectIds, String workspaceId) { + return traceDAO.getStatsByProjectIds(projectIds, workspaceId) + .map(stats -> stats.entrySet().stream() + .map(entry -> { + Map statsMap = entry.getValue() + .stats() + .stream() + .collect(toMap(ProjectStats.ProjectStatItem::getName, + ProjectStats.ProjectStatItem::getValue)); + + return Map.entry(entry.getKey(), statsMap); + }) + .map(entry -> Map.entry(entry.getKey(), entry.getValue())) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue))) + .block(); + } + @Override public List findByIds(String workspaceId, Set ids) { if (ids.isEmpty()) { @@ -328,10 +361,12 @@ private Page findWithLastTraceSorting(int page, int size, @NonNull Proj return repository.findByIds(new HashSet<>(finalIds), workspaceId); }).stream().collect(Collectors.toMap(Project::id, Function.identity())); + Map> projectStats = getProjectStats(finalIds, workspaceId); + // compose the final projects list by the correct order and add last trace to it - List projects = finalIds.stream().map(id -> projectsById.get(id).toBuilder() - .lastUpdatedTraceAt(projectLastUpdatedTraceAtMap.get(id)) - .build()) + List projects = finalIds.stream().map(projectsById::get) + .map(project -> enhanceProject(project, projectLastUpdatedTraceAtMap.get(project.id()), + projectStats.get(project.id()))) .toList(); return new ProjectPage(page, projects.size(), allProjectIdsLastUpdated.size(), projects, @@ -403,6 +438,20 @@ public Project retrieveByName(@NonNull String projectName) { return repository.findByNames(workspaceId, List.of(projectName)) .stream() .findFirst() + .map(project -> { + + Map projectLastUpdatedTraceAtMap = transactionTemplateAsync + .nonTransaction(connection -> { + Set projectIds = Set.of(project.id()); + return traceDAO.getLastUpdatedTraceAt(projectIds, workspaceId, connection); + }).block(); + + Map> projectStats = getProjectStats(List.of(project.id()), + workspaceId); + + return enhanceProject(project, projectLastUpdatedTraceAtMap.get(project.id()), + projectStats.get(project.id())); + }) .orElseThrow(this::createNotFoundError); }); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java index 93ec17835f..6a174dce63 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java @@ -86,6 +86,8 @@ interface TraceDAO { Mono getStats(TraceSearchCriteria criteria); Mono getDailyTraces(); + + Mono> getStatsByProjectIds(List projectIds, String workspaceId); } @Slf4j @@ -639,7 +641,7 @@ AND notEquals(start_time, toDateTime64('1970-01-01 00:00:00.000', 9)), if(length(metadata) > 0, 1, 0) as metadata_count, length(tags) as tags_count FROM traces - WHERE project_id = :project_id + WHERE project_id IN :project_ids AND workspace_id = :workspace_id AND @@ -651,7 +653,7 @@ AND id IN ( FROM feedback_scores WHERE entity_type = 'trace' AND workspace_id = :workspace_id - AND project_id = :project_id + AND project_id IN :project_ids ORDER BY entity_id DESC, last_updated_at DESC LIMIT 1 BY entity_id, name ) @@ -675,7 +677,7 @@ AND id IN ( total_estimated_cost FROM spans WHERE workspace_id = :workspace_id - AND project_id = :project_id + AND project_id IN :project_ids ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id ) @@ -699,7 +701,7 @@ LEFT JOIN ( total_estimated_cost FROM spans WHERE workspace_id = :workspace_id - AND project_id = :project_id + AND project_id IN :project_ids ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id ) @@ -722,7 +724,7 @@ LEFT JOIN ( FROM feedback_scores WHERE entity_type = 'trace' AND workspace_id = :workspace_id - AND project_id = :project_id + AND project_id IN :project_ids ORDER BY entity_id DESC, last_updated_at DESC LIMIT 1 BY entity_id, name ) GROUP BY project_id, entity_id @@ -1174,7 +1176,7 @@ public Mono getStats(@NonNull TraceSearchCriteria criteria) { ST statsSQL = newFindTemplate(SELECT_TRACES_STATS, criteria); var statement = connection.createStatement(statsSQL.render()) - .bind("project_id", criteria.projectId()); + .bind("project_ids", List.of(criteria.projectId())); bindSearchCriteria(criteria, statement); @@ -1197,6 +1199,30 @@ public Mono getDailyTraces() { .reduce(0L, Long::sum); } + @Override + public Mono> getStatsByProjectIds(@NonNull List projectIds, + @NonNull String workspaceId) { + + if (projectIds.isEmpty()) { + return Mono.just(Map.of()); + } + + return asyncTemplate + .nonTransaction(connection -> { + Statement statement = connection.createStatement(new ST(SELECT_TRACES_STATS).render()) + .bind("project_ids", projectIds) + .bind("workspace_id", workspaceId); + + return Mono.from(statement.execute()) + .flatMapMany(result -> result.map((row, rowMetadata) -> Map.of( + row.get("project_id", UUID.class), + StatsMapper.mapProjectStats(row, "trace_count")))) + .map(Map::entrySet) + .flatMap(Flux::fromIterable) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + }); + } + @Override @WithSpan public Mono> getLastUpdatedTraceAt( diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/stats/StatsMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/stats/StatsMapper.java index 25ce4d1d93..2bbb24c04d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/stats/StatsMapper.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/stats/StatsMapper.java @@ -1,5 +1,6 @@ package com.comet.opik.domain.stats; +import com.comet.opik.api.FeedbackScoreAverage; import com.comet.opik.api.ProjectStats; import io.r2dbc.spi.Row; @@ -9,48 +10,59 @@ import java.util.Optional; import java.util.stream.Stream; +import static com.comet.opik.api.ProjectStats.AvgValueStat; +import static com.comet.opik.api.ProjectStats.CountValueStat; +import static com.comet.opik.api.ProjectStats.PercentageValueStat; +import static com.comet.opik.api.ProjectStats.PercentageValues; +import static java.util.stream.Collectors.toMap; + public class StatsMapper { + public static final String USAGE = "usage"; + public static final String FEEDBACK_SCORE = "feedback_scores"; + public static final String TOTAL_ESTIMATED_COST = "total_estimated_cost"; + public static final String DURATION = "duration"; + public static ProjectStats mapProjectStats(Row row, String entityCountLabel) { var stats = Stream.>builder() - .add(new ProjectStats.CountValueStat(entityCountLabel, + .add(new CountValueStat(entityCountLabel, row.get(entityCountLabel, Long.class))) - .add(new ProjectStats.PercentageValueStat("duration", Optional - .ofNullable(row.get("duration", List.class)) - .map(durations -> new ProjectStats.PercentageValues( + .add(new PercentageValueStat(DURATION, Optional + .ofNullable(row.get(DURATION, List.class)) + .map(durations -> new PercentageValues( getP(durations, 0), getP(durations, 1), getP(durations, 2))) .orElse(null))) - .add(new ProjectStats.CountValueStat("input", row.get("input", Long.class))) - .add(new ProjectStats.CountValueStat("output", row.get("output", Long.class))) - .add(new ProjectStats.CountValueStat("metadata", row.get("metadata", Long.class))) - .add(new ProjectStats.AvgValueStat("tags", row.get("tags", Double.class))); + .add(new CountValueStat("input", row.get("input", Long.class))) + .add(new CountValueStat("output", row.get("output", Long.class))) + .add(new CountValueStat("metadata", row.get("metadata", Long.class))) + .add(new AvgValueStat("tags", row.get("tags", Double.class))); BigDecimal totalEstimatedCost = row.get("total_estimated_cost_avg", BigDecimal.class); if (totalEstimatedCost == null) { totalEstimatedCost = BigDecimal.ZERO; } - stats.add(new ProjectStats.AvgValueStat("total_estimated_cost", totalEstimatedCost.doubleValue())); + stats.add(new AvgValueStat(TOTAL_ESTIMATED_COST, totalEstimatedCost.doubleValue())); - Map usage = row.get("usage", Map.class); - Map feedbackScores = row.get("feedback_scores", Map.class); + Map usage = row.get(USAGE, Map.class); + Map feedbackScores = row.get(FEEDBACK_SCORE, Map.class); if (usage != null) { usage.keySet() .stream() .sorted() .forEach(key -> stats - .add(new ProjectStats.AvgValueStat("%s.%s".formatted("usage", key), usage.get(key)))); + .add(new AvgValueStat("%s.%s".formatted(USAGE, key), usage.get(key)))); } if (feedbackScores != null) { feedbackScores.keySet() .stream() .sorted() - .forEach(key -> stats.add(new ProjectStats.AvgValueStat("%s.%s".formatted("feedback_score", key), + .forEach(key -> stats.add(new AvgValueStat("%s.%s".formatted(FEEDBACK_SCORE, key), feedbackScores.get(key)))); } @@ -61,4 +73,38 @@ private static BigDecimal getP(List durations, int index) { return durations.get(index); } + public static Map getStatsUsage(Map stats) { + return Optional.ofNullable(stats) + .map(map -> map.keySet() + .stream() + .filter(k -> k.startsWith(USAGE)) + .map( + k -> Map.entry(k.substring("%s.".formatted(USAGE).length()), (Double) map.get(k))) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue))) + .orElse(null); + } + + public static Double getStatsTotalEstimatedCost(Map stats) { + return Optional.ofNullable(stats) + .map(map -> map.get(TOTAL_ESTIMATED_COST)) + .map(v -> (Double) v) + .orElse(null); + } + + public static List getStatsFeedbackScores(Map stats) { + return Optional.ofNullable(stats) + .map(map -> map.keySet() + .stream() + .filter(k -> k.startsWith(FEEDBACK_SCORE)) + .map(k -> new FeedbackScoreAverage(k.substring("%s.".formatted(FEEDBACK_SCORE).length()), + BigDecimal.valueOf((Double) map.get(k)))) + .toList()) + .orElse(null); + } + + public static PercentageValues getStatsDuration(Map stats) { + return Optional.ofNullable(stats) + .map(map -> (PercentageValues) map.get(DURATION)) + .orElse(null); + } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/BigDecimalCollectors.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/BigDecimalCollectors.java new file mode 100644 index 0000000000..c34e7250e5 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/BigDecimalCollectors.java @@ -0,0 +1,30 @@ +package com.comet.opik.api.resources.utils; + +import com.comet.opik.utils.ValidationUtils; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.function.Function; +import java.util.stream.Collector; + +public class BigDecimalCollectors { + public static Collector averagingBigDecimal(Function mapper) { + return Collector.of( + () -> new BigDecimal[]{BigDecimal.ZERO, BigDecimal.ZERO}, + (result, value) -> { + BigDecimal mappedValue = mapper.apply(value); + if (mappedValue != null) { + result[0] = result[0].add(mappedValue); + result[1] = result[1].add(BigDecimal.ONE); + } + }, + (result1, result2) -> { + result1[0] = result1[0].add(result2[0]); + result1[1] = result1[1].add(result2[1]); + return result1; + }, + result -> result[1].compareTo(BigDecimal.ZERO) == 0 + ? BigDecimal.ZERO + : result[0].divide(result[1], ValidationUtils.SCALE, RoundingMode.HALF_UP)); + } +} diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/StatsUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/StatsUtils.java index cd1123e7d4..c120fad665 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/StatsUtils.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/StatsUtils.java @@ -17,6 +17,7 @@ import java.math.RoundingMode; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.Collection; import java.util.Comparator; @@ -318,4 +319,50 @@ public static int bigDecimalComparator(BigDecimal v1, BigDecimal v2) { return strippedV1.toBigInteger().compareTo(strippedV2.toBigInteger()); } + public static int closeToEpsilonComparator(Object v1, Object v2) { + //TODO This is a workaround to compare averages originating from BigDecimals calculated by code vs. the same + // calculated by Clickhouse + + // Handle null cases (if nulls are allowed) + if (v1 == null && v2 == null) { + return 0; // Both null are considered equal + } else if (v1 == null) { + return -1; // Null is considered "less than" + } else if (v2 == null) { + return 1; // Non-null is considered "greater than" + } + + if (v1.equals(v2)) { + return 0; + } + + Number numv1 = (Number) v1, numv2 = (Number) v2; + + // Define an absolute tolerance for comparison + double epsilon = .00001; + + // Calculate the absolute difference + double difference = Math.abs(numv1.doubleValue() - numv2.doubleValue()); + + // If the difference is within the tolerance, consider them equal + if (difference <= epsilon) { + return 0; + } + + // otherwise return ordinary comparison + return 1; + } + + public static Map aggregateSpansUsage(List spans) { + return spans.stream() + .flatMap(span -> span.usage().entrySet().stream()) + .map(entry -> new AbstractMap.SimpleEntry<>(entry.getKey(), Long.valueOf(entry.getValue()))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, Long::sum)); + } + + public static BigDecimal aggregateSpansCost(List spans) { + return spans.stream() + .map(span -> ModelPrice.fromString(span.model()).calculateCost(span.usage())) + .reduce(BigDecimal.ZERO, BigDecimal::add); + } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java index 6b85fc99ab..7c093bd7cc 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java @@ -2,21 +2,30 @@ import com.comet.opik.TestComparators; import com.comet.opik.api.BatchDelete; +import com.comet.opik.api.FeedbackScore; +import com.comet.opik.api.FeedbackScoreAverage; +import com.comet.opik.api.FeedbackScoreBatchItem; import com.comet.opik.api.Project; import com.comet.opik.api.ProjectRetrieve; +import com.comet.opik.api.ProjectStats; import com.comet.opik.api.ProjectUpdate; +import com.comet.opik.api.Span; import com.comet.opik.api.Trace; import com.comet.opik.api.TraceUpdate; import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.BigDecimalCollectors; import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; import com.comet.opik.api.resources.utils.ClientSupportUtils; +import com.comet.opik.api.resources.utils.DurationUtils; import com.comet.opik.api.resources.utils.MigrationUtils; import com.comet.opik.api.resources.utils.MySQLContainerUtils; import com.comet.opik.api.resources.utils.RedisContainerUtils; +import com.comet.opik.api.resources.utils.StatsUtils; import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; import com.comet.opik.api.resources.utils.TestUtils; import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.api.resources.utils.resources.SpanResourceClient; import com.comet.opik.api.resources.utils.resources.TraceResourceClient; import com.comet.opik.api.sorting.Direction; import com.comet.opik.api.sorting.SortableFields; @@ -26,13 +35,14 @@ import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; import com.comet.opik.podam.PodamFactoryUtils; import com.comet.opik.utils.JsonUtils; +import com.comet.opik.utils.ValidationUtils; import com.github.tomakehurst.wiremock.client.WireMock; import com.redis.testcontainers.RedisContainer; import jakarta.ws.rs.HttpMethod; import jakarta.ws.rs.client.Entity; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; -import org.apache.hc.core5.http.HttpStatus; +import org.apache.http.HttpStatus; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -54,13 +64,19 @@ import ru.vyarus.dropwizard.guice.test.ClientSupport; import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; import uk.co.jemos.podam.api.PodamFactory; +import uk.co.jemos.podam.api.PodamUtils; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.sql.SQLException; import java.time.Instant; +import java.util.Comparator; import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.UUID; import java.util.regex.Pattern; @@ -80,6 +96,9 @@ import static com.github.tomakehurst.wiremock.client.WireMock.okJson; import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static java.util.stream.Collectors.averagingDouble; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.toMap; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.params.provider.Arguments.arguments; @@ -91,7 +110,7 @@ class ProjectsResourceTest { public static final String URL_TEMPLATE = "%s/v1/private/projects"; public static final String URL_TEMPLATE_TRACE = "%s/v1/private/traces"; public static final String[] IGNORED_FIELDS = {"createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", - "lastUpdatedTraceAt"}; + "lastUpdatedTraceAt", "feedbackScores", "duration", "totalEstimatedCost", "usage"}; private static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); @@ -125,6 +144,7 @@ class ProjectsResourceTest { private ClientSupport client; private ProjectService projectService; private TraceResourceClient traceResourceClient; + private SpanResourceClient spanResourceClient; @BeforeAll void setUpAll(ClientSupport client, Jdbi jdbi, ProjectService projectService) throws SQLException { @@ -145,6 +165,7 @@ void setUpAll(ClientSupport client, Jdbi jdbi, ProjectService projectService) th mockTargetWorkspace(API_KEY, TEST_WORKSPACE, WORKSPACE_ID); this.traceResourceClient = new TraceResourceClient(this.client, baseURI); + this.spanResourceClient = new SpanResourceClient(this.client, baseURI); } private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { @@ -1094,6 +1115,252 @@ void getProjects__whenProjectsHasTraces__thenReturnProjectWithLastUpdatedTraceAt expectedProject3)); } + @Test + @DisplayName("when projects with traces, spans, feedback scores, and usage, then return project aggregations") + void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsage__thenReturnProjectAggregations() { + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class) + .parallelStream() + .map(project -> project.toBuilder() + .id(createProject(project, apiKey, workspaceName)) + .totalEstimatedCost(null) + .usage(null) + .feedbackScores(null) + .duration(null) + .build()) + .toList(); + + List expectedProjects = projects.parallelStream() + .map(project -> buildProjectStats(project, apiKey, workspaceName)) + .sorted(Comparator.comparing(Project::id).reversed()) + .toList(); + + var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get(); + + var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK); + + assertThat(expectedProjects).hasSameSizeAs(actualEntity.content()); + + assertThat(actualEntity.content()) + .usingRecursiveComparison() + .ignoringFields("createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", "lastUpdatedTraceAt") + .ignoringCollectionOrder() + .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) + .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") + .isEqualTo(expectedProjects); + } + + @Test + @DisplayName("when projects with traces, spans, feedback scores, and usage and sorted by last updated trace at, then return project aggregations") + void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsageSortedLastTrace__thenReturnProjectAggregations() { + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class) + .parallelStream() + .map(project -> project.toBuilder() + .id(createProject(project, apiKey, workspaceName)) + .totalEstimatedCost(null) + .usage(null) + .feedbackScores(null) + .duration(null) + .build()) + .toList(); + + List expectedProjects = projects.parallelStream() + .map(project -> buildProjectStats(project, apiKey, workspaceName)) + .sorted(Comparator.comparing(Project::id).reversed()) + .toList(); + + var sorting = List.of(SortingField.builder() + .field(SortableFields.LAST_UPDATED_TRACE_AT) + .direction(Direction.DESC) + .build()); + + var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .queryParam("sorting", URLEncoder.encode(JsonUtils.writeValueAsString(sorting), + StandardCharsets.UTF_8)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get(); + + var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK); + + assertThat(expectedProjects).hasSameSizeAs(actualEntity.content()); + + assertThat(actualEntity.content()) + .usingRecursiveComparison() + .ignoringFields("createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", "lastUpdatedTraceAt") + .ignoringCollectionOrder() + .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) + .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") + .isEqualTo(expectedProjects); + } + + @Test + @DisplayName("when projects without traces, spans, feedback scores, and usage, then return project aggregations") + void getProjects__whenProjectsHasNoTracesSpansFeedbackScoresAndUsage__thenReturnProjectAggregations() { + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class) + .parallelStream() + .map(project -> project.toBuilder() + .id(createProject(project, apiKey, workspaceName)) + .totalEstimatedCost(null) + .usage(null) + .feedbackScores(null) + .duration(null) + .build()) + .toList(); + + List expectedProjects = projects.parallelStream() + .map(project -> project.toBuilder() + .duration(null) + .totalEstimatedCost(null) + .usage(null) + .feedbackScores(null) + .build()) + .sorted(Comparator.comparing(Project::id).reversed()) + .toList(); + + var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get(); + + var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK); + + assertThat(expectedProjects).hasSameSizeAs(actualEntity.content()); + + assertThat(actualEntity.content()) + .usingRecursiveComparison() + .ignoringFields("createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", "lastUpdatedTraceAt") + .ignoringCollectionOrder() + .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) + .withComparatorForFields(StatsUtils::closeToEpsilonComparator, "totalEstimatedCost") + .isEqualTo(expectedProjects); + } + + private Project buildProjectStats(Project project, String apiKey, String workspaceName) { + var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() + .map(trace -> { + Instant startTime = Instant.now(); + Instant endTime = startTime.plusMillis(PodamUtils.getIntegerInRange(1, 1000)); + return trace.toBuilder() + .projectName(project.name()) + .startTime(startTime) + .endTime(endTime) + .duration(DurationUtils.getDurationInMillisWithSubMilliPrecision(startTime, endTime)) + .build(); + }) + .toList(); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); + + List scores = PodamFactoryUtils.manufacturePojoList(factory, + FeedbackScoreBatchItem.class); + + traces = traces.stream().map(trace -> { + List spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() + .map(span -> span.toBuilder() + .usage(spanResourceClient.getTokenUsage()) + .model(spanResourceClient.randomModelPrice().getName()) + .traceId(trace.id()) + .projectName(trace.projectName()) + .build()) + .toList(); + + spanResourceClient.batchCreateSpans(spans, apiKey, workspaceName); + + List feedbackScores = scores.stream() + .map(feedbackScore -> feedbackScore.toBuilder() + .projectId(project.id()) + .projectName(project.name()) + .id(trace.id()) + .build()) + .toList(); + + traceResourceClient.feedbackScores(feedbackScores, apiKey, workspaceName); + + return trace.toBuilder() + .feedbackScores( + feedbackScores.stream() + .map(score -> FeedbackScore.builder() + .value(score.value()) + .name(score.name()) + .build()) + .toList()) + .usage(StatsUtils.aggregateSpansUsage(spans)) + .totalEstimatedCost(StatsUtils.aggregateSpansCost(spans)) + .build(); + }).toList(); + + List durations = StatsUtils.calculateQuantiles( + traces.stream() + .map(Trace::duration) + .toList(), + List.of(0.5, 0.90, 0.99)); + + return project.toBuilder() + .duration(new ProjectStats.PercentageValues(durations.get(0), durations.get(1), durations.get(2))) + .totalEstimatedCost(getTotalEstimatedCost(traces)) + .usage(traces.stream() + .map(Trace::usage) + .flatMap(usage -> usage.entrySet().stream()) + .collect(groupingBy(Map.Entry::getKey, averagingDouble(Map.Entry::getValue)))) + .feedbackScores(getScoreAverages(traces)) + .build(); + } + + private List getScoreAverages(List traces) { + return traces.stream() + .map(Trace::feedbackScores) + .flatMap(List::stream) + .collect(groupingBy(FeedbackScore::name, + BigDecimalCollectors.averagingBigDecimal(FeedbackScore::value))) + .entrySet() + .stream() + .map(entry -> FeedbackScoreAverage.builder() + .name(entry.getKey()) + .value(entry.getValue()) + .build()) + .toList(); + } + + private double getTotalEstimatedCost(List traces) { + long count = traces.stream() + .map(Trace::totalEstimatedCost) + .filter(Objects::nonNull) + .filter(cost -> cost.compareTo(BigDecimal.ZERO) > 0) + .count(); + + return traces.stream() + .map(Trace::totalEstimatedCost) + .reduce(BigDecimal.ZERO, BigDecimal::add) + .divide(BigDecimal.valueOf(count), ValidationUtils.SCALE, RoundingMode.HALF_UP).doubleValue(); + } + @Test @DisplayName("when projects is with traces created in batch, then return project with last updated trace at") void getProjects__whenProjectsHasTracesBatch__thenReturnProjectWithLastUpdatedTraceAt() { @@ -1194,17 +1461,19 @@ void getProjects__whenTraceIsUpdated__thenUpdateProjectsLastUpdatedTraceAt() { } private void assertAllProjectsHavePersistedLastTraceAt(String workspaceId, List expectedProjects) { - List dbProjects = projectService.findByIds(workspaceId, expectedProjects.stream() - .map(Project::id).collect(Collectors.toUnmodifiableSet())); - - for (Project project : expectedProjects) { - Awaitility.await().untilAsserted(() -> { - assertThat(dbProjects.stream().filter(dbProject -> dbProject.id().equals(project.id())) - .findFirst().orElseThrow().lastUpdatedTraceAt()) - .usingComparator(TestComparators::compareMicroNanoTime) - .isEqualTo(project.lastUpdatedTraceAt()); - }); - } + Awaitility.await().untilAsserted(() -> { + List dbProjects = projectService.findByIds(workspaceId, expectedProjects.stream() + .map(Project::id).collect(Collectors.toUnmodifiableSet())); + Map actualLastTraceByProjectId = dbProjects.stream() + .collect(toMap(Project::id, Project::lastUpdatedTraceAt)); + Map expectedLastTraceByProjectId = expectedProjects.stream() + .collect(toMap(Project::id, Project::lastUpdatedTraceAt)); + + assertThat(actualLastTraceByProjectId) + .usingRecursiveComparison() + .withComparatorForType(TestComparators::compareMicroNanoTime, Instant.class) + .isEqualTo(expectedLastTraceByProjectId); + }); } } @@ -1274,17 +1543,8 @@ private UUID createCreateTrace(String projectName, String apiKey, String workspa .projectName(projectName) .build(); - try (var actualResponse = client.target(URL_TEMPLATE_TRACE.formatted(baseURI)) - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .post(Entity.json(trace))) { - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - assertThat(actualResponse.hasEntity()).isFalse(); - - return TestUtils.getIdFromLocation(actualResponse.getLocation()); - } + traceResourceClient.batchCreateTraces(List.of(trace), apiKey, workspaceName); + return trace.id(); } private Trace getTrace(UUID id, String apiKey, String workspaceName) {