From fb94fbd2f4d52fc626d3c70646a1cb5ba01090eb Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Sat, 14 Dec 2024 16:51:51 +0100 Subject: [PATCH] [OPIK-287] Add project level aggregations --- .../main/java/com/comet/opik/api/Project.java | 13 +- .../com/comet/opik/domain/ProjectService.java | 54 +++- .../java/com/comet/opik/domain/TraceDAO.java | 38 ++- .../comet/opik/domain/stats/StatsMapper.java | 72 ++++- .../resources/utils/BigDecimalCollectors.java | 53 ++++ .../utils/resources/TraceResourceClient.java | 19 ++ .../v1/priv/ProjectsResourceTest.java | 280 +++++++++++++++--- 7 files changed, 472 insertions(+), 57 deletions(-) create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/BigDecimalCollectors.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java index 5e8b13133b..8e39e14a98 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Project.java @@ -11,8 +11,11 @@ import java.time.Instant; import java.util.List; +import java.util.Map; import java.util.UUID; +import static com.comet.opik.api.ProjectStats.PercentageValues; + @Builder(toBuilder = true) @JsonIgnoreProperties(ignoreUnknown = true) // This annotation is used to specify the strategy to be used for naming of properties for the annotated type. Required so that OpenAPI schema generation uses snake_case @@ -29,7 +32,15 @@ public record Project( @JsonView({Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, @JsonView({Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy, @JsonView({ - Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Instant lastUpdatedTraceAt){ + Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Instant lastUpdatedTraceAt, + @JsonView({ + Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable List feedbackScores, + @JsonView({ + Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable PercentageValues duration, + @JsonView({ + Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Double totalEstimatedCost, + @JsonView({ + Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Map usage){ public static class View { public static class Write { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java index 6e25544810..109096c306 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectService.java @@ -5,6 +5,7 @@ import com.comet.opik.api.Project.ProjectPage; import com.comet.opik.api.ProjectCriteria; import com.comet.opik.api.ProjectIdLastUpdated; +import com.comet.opik.api.ProjectStats; import com.comet.opik.api.ProjectUpdate; import com.comet.opik.api.error.EntityAlreadyExistsException; import com.comet.opik.api.error.ErrorMessage; @@ -13,6 +14,7 @@ import com.comet.opik.api.sorting.SortingFactoryProjects; import com.comet.opik.api.sorting.SortingField; import com.comet.opik.domain.sorting.SortingQueryBuilder; +import com.comet.opik.domain.stats.StatsMapper; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.db.TransactionTemplateAsync; import com.comet.opik.utils.PaginationUtils; @@ -44,6 +46,7 @@ import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY; import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; import static java.util.Collections.reverseOrder; +import static java.util.stream.Collectors.toMap; import static java.util.stream.Collectors.toSet; import static java.util.stream.Collectors.toUnmodifiableSet; @@ -199,8 +202,19 @@ public Project get(@NonNull UUID id, @NonNull String workspaceId) { .nonTransaction(connection -> traceDAO.getLastUpdatedTraceAt(Set.of(id), workspaceId, connection)) .block(); + Map> projectStats = getProjectStats(List.of(id), workspaceId); + + return enhanceProject(project, lastUpdatedTraceAt, projectStats); + } + + private Project enhanceProject(Project project, Map lastUpdatedTraceAt, + Map> projectStats) { return project.toBuilder() - .lastUpdatedTraceAt(lastUpdatedTraceAt.get(id)) + .lastUpdatedTraceAt(lastUpdatedTraceAt.get(project.id())) + .feedbackScores(StatsMapper.getStatsFeedbackScores(projectStats.get(project.id()))) + .duration(StatsMapper.getStatsDuration(projectStats.get(project.id()))) + .totalEstimatedCost(StatsMapper.getStatsTotalEstimatedCost(projectStats.get(project.id()))) + .usage(StatsMapper.getStatsUsage(projectStats.get(project.id()))) .build(); } @@ -271,17 +285,36 @@ public Page find(int page, int size, @NonNull ProjectCriteria criteria, return traceDAO.getLastUpdatedTraceAt(projectIds, workspaceId, connection); }).block(); + List projectIds = projectRecordSet.content.stream().map(Project::id).toList(); + + Map> projectStats = getProjectStats(projectIds, workspaceId); + List projects = projectRecordSet.content() .stream() - .map(project -> project.toBuilder() - .lastUpdatedTraceAt(projectLastUpdatedTraceAtMap.get(project.id())) - .build()) + .map(project -> enhanceProject(project, projectLastUpdatedTraceAtMap, projectStats)) .toList(); return new ProjectPage(page, projects.size(), projectRecordSet.total(), projects, sortingFactory.getSortableFields()); } + private Map> getProjectStats(List projectIds, String workspaceId) { + return traceDAO.getStatsByProjectIds(projectIds, workspaceId) + .map(stats -> stats.entrySet().stream() + .map(entry -> { + Map statsMap = entry.getValue() + .stats() + .stream() + .collect(toMap(ProjectStats.ProjectStatItem::getName, + ProjectStats.ProjectStatItem::getValue)); + + return Map.entry(entry.getKey(), statsMap); + }) + .map(entry -> Map.entry(entry.getKey(), entry.getValue())) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue))) + .block(); + } + @Override public List findByIds(String workspaceId, Set ids) { if (ids.isEmpty()) { @@ -403,6 +436,19 @@ public Project retrieveByName(@NonNull String projectName) { return repository.findByNames(workspaceId, List.of(projectName)) .stream() .findFirst() + .map(project -> { + + Map projectLastUpdatedTraceAtMap = transactionTemplateAsync + .nonTransaction(connection -> { + Set projectIds = Set.of(project.id()); + return traceDAO.getLastUpdatedTraceAt(projectIds, workspaceId, connection); + }).block(); + + Map> projectStats = getProjectStats(List.of(project.id()), + workspaceId); + + return enhanceProject(project, projectLastUpdatedTraceAtMap, projectStats); + }) .orElseThrow(this::createNotFoundError); }); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java index 1a53041a5f..bbb99c6c5a 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java @@ -85,6 +85,8 @@ interface TraceDAO { Mono getStats(TraceSearchCriteria criteria); Mono getDailyTraces(); + + Mono> getStatsByProjectIds(List projectIds, String workspaceId); } @Slf4j @@ -609,7 +611,7 @@ LEFT JOIN ( if(length(metadata) > 0, 1, 0) as metadata_count, length(tags) as tags_count FROM traces - WHERE project_id = :project_id + WHERE project_id IN :project_ids AND workspace_id = :workspace_id AND @@ -621,7 +623,7 @@ AND id IN ( FROM feedback_scores WHERE entity_type = 'trace' AND workspace_id = :workspace_id - AND project_id = :project_id + AND project_id IN :project_ids ORDER BY entity_id DESC, last_updated_at DESC LIMIT 1 BY entity_id, name ) @@ -645,7 +647,7 @@ AND id IN ( total_estimated_cost FROM spans WHERE workspace_id = :workspace_id - AND project_id = :project_id + AND project_id IN :project_ids ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id ) @@ -669,7 +671,7 @@ LEFT JOIN ( total_estimated_cost FROM spans WHERE workspace_id = :workspace_id - AND project_id = :project_id + AND project_id IN :project_ids ORDER BY id DESC, last_updated_at DESC LIMIT 1 BY id ) @@ -692,7 +694,7 @@ LEFT JOIN ( FROM feedback_scores WHERE entity_type = 'trace' AND workspace_id = :workspace_id - AND project_id = :project_id + AND project_id IN :project_ids ORDER BY entity_id DESC, last_updated_at DESC LIMIT 1 BY entity_id, name ) GROUP BY project_id, entity_id @@ -1125,7 +1127,7 @@ public Mono getStats(@NonNull TraceSearchCriteria criteria) { ST statsSQL = newFindTemplate(SELECT_TRACES_STATS, criteria); var statement = connection.createStatement(statsSQL.render()) - .bind("project_id", criteria.projectId()); + .bind("project_ids", List.of(criteria.projectId()).toArray(UUID[]::new)); bindSearchCriteria(criteria, statement); @@ -1148,6 +1150,30 @@ public Mono getDailyTraces() { .reduce(0L, Long::sum); } + @Override + public Mono> getStatsByProjectIds(@NonNull List projectIds, + @NonNull String workspaceId) { + + if (projectIds.isEmpty()) { + return Mono.just(Map.of()); + } + + return asyncTemplate + .nonTransaction(connection -> { + Statement statement = connection.createStatement(new ST(SELECT_TRACES_STATS).render()) + .bind("project_ids", projectIds.toArray(UUID[]::new)) + .bind("workspace_id", workspaceId); + + return Mono.from(statement.execute()) + .flatMapMany(result -> result.map((row, rowMetadata) -> Map.of( + row.get("project_id", UUID.class), + StatsMapper.mapProjectStats(row, "trace_count")))) + .map(Map::entrySet) + .flatMap(Flux::fromIterable) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + }); + } + @Override @WithSpan public Mono> getLastUpdatedTraceAt( diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/stats/StatsMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/stats/StatsMapper.java index 25ce4d1d93..2bbb24c04d 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/stats/StatsMapper.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/stats/StatsMapper.java @@ -1,5 +1,6 @@ package com.comet.opik.domain.stats; +import com.comet.opik.api.FeedbackScoreAverage; import com.comet.opik.api.ProjectStats; import io.r2dbc.spi.Row; @@ -9,48 +10,59 @@ import java.util.Optional; import java.util.stream.Stream; +import static com.comet.opik.api.ProjectStats.AvgValueStat; +import static com.comet.opik.api.ProjectStats.CountValueStat; +import static com.comet.opik.api.ProjectStats.PercentageValueStat; +import static com.comet.opik.api.ProjectStats.PercentageValues; +import static java.util.stream.Collectors.toMap; + public class StatsMapper { + public static final String USAGE = "usage"; + public static final String FEEDBACK_SCORE = "feedback_scores"; + public static final String TOTAL_ESTIMATED_COST = "total_estimated_cost"; + public static final String DURATION = "duration"; + public static ProjectStats mapProjectStats(Row row, String entityCountLabel) { var stats = Stream.>builder() - .add(new ProjectStats.CountValueStat(entityCountLabel, + .add(new CountValueStat(entityCountLabel, row.get(entityCountLabel, Long.class))) - .add(new ProjectStats.PercentageValueStat("duration", Optional - .ofNullable(row.get("duration", List.class)) - .map(durations -> new ProjectStats.PercentageValues( + .add(new PercentageValueStat(DURATION, Optional + .ofNullable(row.get(DURATION, List.class)) + .map(durations -> new PercentageValues( getP(durations, 0), getP(durations, 1), getP(durations, 2))) .orElse(null))) - .add(new ProjectStats.CountValueStat("input", row.get("input", Long.class))) - .add(new ProjectStats.CountValueStat("output", row.get("output", Long.class))) - .add(new ProjectStats.CountValueStat("metadata", row.get("metadata", Long.class))) - .add(new ProjectStats.AvgValueStat("tags", row.get("tags", Double.class))); + .add(new CountValueStat("input", row.get("input", Long.class))) + .add(new CountValueStat("output", row.get("output", Long.class))) + .add(new CountValueStat("metadata", row.get("metadata", Long.class))) + .add(new AvgValueStat("tags", row.get("tags", Double.class))); BigDecimal totalEstimatedCost = row.get("total_estimated_cost_avg", BigDecimal.class); if (totalEstimatedCost == null) { totalEstimatedCost = BigDecimal.ZERO; } - stats.add(new ProjectStats.AvgValueStat("total_estimated_cost", totalEstimatedCost.doubleValue())); + stats.add(new AvgValueStat(TOTAL_ESTIMATED_COST, totalEstimatedCost.doubleValue())); - Map usage = row.get("usage", Map.class); - Map feedbackScores = row.get("feedback_scores", Map.class); + Map usage = row.get(USAGE, Map.class); + Map feedbackScores = row.get(FEEDBACK_SCORE, Map.class); if (usage != null) { usage.keySet() .stream() .sorted() .forEach(key -> stats - .add(new ProjectStats.AvgValueStat("%s.%s".formatted("usage", key), usage.get(key)))); + .add(new AvgValueStat("%s.%s".formatted(USAGE, key), usage.get(key)))); } if (feedbackScores != null) { feedbackScores.keySet() .stream() .sorted() - .forEach(key -> stats.add(new ProjectStats.AvgValueStat("%s.%s".formatted("feedback_score", key), + .forEach(key -> stats.add(new AvgValueStat("%s.%s".formatted(FEEDBACK_SCORE, key), feedbackScores.get(key)))); } @@ -61,4 +73,38 @@ private static BigDecimal getP(List durations, int index) { return durations.get(index); } + public static Map getStatsUsage(Map stats) { + return Optional.ofNullable(stats) + .map(map -> map.keySet() + .stream() + .filter(k -> k.startsWith(USAGE)) + .map( + k -> Map.entry(k.substring("%s.".formatted(USAGE).length()), (Double) map.get(k))) + .collect(toMap(Map.Entry::getKey, Map.Entry::getValue))) + .orElse(null); + } + + public static Double getStatsTotalEstimatedCost(Map stats) { + return Optional.ofNullable(stats) + .map(map -> map.get(TOTAL_ESTIMATED_COST)) + .map(v -> (Double) v) + .orElse(null); + } + + public static List getStatsFeedbackScores(Map stats) { + return Optional.ofNullable(stats) + .map(map -> map.keySet() + .stream() + .filter(k -> k.startsWith(FEEDBACK_SCORE)) + .map(k -> new FeedbackScoreAverage(k.substring("%s.".formatted(FEEDBACK_SCORE).length()), + BigDecimal.valueOf((Double) map.get(k)))) + .toList()) + .orElse(null); + } + + public static PercentageValues getStatsDuration(Map stats) { + return Optional.ofNullable(stats) + .map(map -> (PercentageValues) map.get(DURATION)) + .orElse(null); + } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/BigDecimalCollectors.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/BigDecimalCollectors.java new file mode 100644 index 0000000000..0de3476a8b --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/BigDecimalCollectors.java @@ -0,0 +1,53 @@ +package com.comet.opik.api.resources.utils; + +import com.comet.opik.utils.ValidationUtils; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.function.Function; +import java.util.stream.Collector; + +public class BigDecimalCollectors { + + public static Collector averagingBigDecimal() { + return Collector.of( + // Supplier: Create an array with two elements to hold total and count + () -> new BigDecimal[]{BigDecimal.ZERO, BigDecimal.ZERO}, + // Accumulator: Update total and count + (result, value) -> { + result[0] = result[0].add(value); // Accumulate total + result[1] = result[1].add(BigDecimal.ONE); // Increment count + }, + // Combiner: Merge two arrays (used for parallel streams) + (result1, result2) -> { + result1[0] = result1[0].add(result2[0]); // Combine totals + result1[1] = result1[1].add(result2[1]); // Combine counts + return result1; + }, + // Finisher: Compute the average (total / count) with rounding + result -> result[1].compareTo(BigDecimal.ZERO) == 0 + ? BigDecimal.ZERO // Avoid division by zero + : result[0].divide(result[1], ValidationUtils.SCALE, RoundingMode.HALF_UP)); + } + + public static Collector averagingBigDecimal(Function mapper) { + return Collector.of( + () -> new BigDecimal[]{BigDecimal.ZERO, BigDecimal.ZERO}, + (result, value) -> { + BigDecimal mappedValue = mapper.apply(value); + if (mappedValue != null) { + result[0] = result[0].add(mappedValue); + result[1] = result[1].add(BigDecimal.ONE); + } + }, + (result1, result2) -> { + result1[0] = result1[0].add(result2[0]); + result1[1] = result1[1].add(result2[1]); + return result1; + }, + result -> result[1].compareTo(BigDecimal.ZERO) == 0 + ? BigDecimal.ZERO + : result[0].divide(result[1], ValidationUtils.SCALE, RoundingMode.HALF_UP)); + } + +} diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java index 5e05510943..92e1cd4132 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/TraceResourceClient.java @@ -4,10 +4,12 @@ import com.comet.opik.api.FeedbackScore; import com.comet.opik.api.FeedbackScoreBatch; import com.comet.opik.api.FeedbackScoreBatchItem; +import com.comet.opik.api.Span; import com.comet.opik.api.Trace; import com.comet.opik.api.TraceBatch; import com.comet.opik.api.TraceUpdate; import com.comet.opik.api.resources.utils.TestUtils; +import com.comet.opik.domain.cost.ModelPrice; import jakarta.ws.rs.HttpMethod; import jakarta.ws.rs.client.Entity; import jakarta.ws.rs.core.HttpHeaders; @@ -16,8 +18,12 @@ import org.apache.http.HttpStatus; import ru.vyarus.dropwizard.guice.test.ClientSupport; +import java.math.BigDecimal; +import java.util.AbstractMap; import java.util.List; +import java.util.Map; import java.util.UUID; +import java.util.stream.Collectors; import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; import static org.assertj.core.api.Assertions.assertThat; @@ -140,4 +146,17 @@ public void updateTrace(UUID id, TraceUpdate traceUpdate, String apiKey, String assertThat(actualResponse.hasEntity()).isFalse(); } } + + public Map aggregateSpansUsage(List spans) { + return spans.stream() + .flatMap(span -> span.usage().entrySet().stream()) + .map(entry -> new AbstractMap.SimpleEntry<>(entry.getKey(), Long.valueOf(entry.getValue()))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, Long::sum)); + } + + public BigDecimal aggregateSpansCost(List spans) { + return spans.stream() + .map(span -> ModelPrice.fromString(span.model()).calculateCost(span.usage())) + .reduce(BigDecimal.ZERO, BigDecimal::add); + } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java index 6b85fc99ab..3cc21bd8ba 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java @@ -2,37 +2,46 @@ import com.comet.opik.TestComparators; import com.comet.opik.api.BatchDelete; +import com.comet.opik.api.FeedbackScore; +import com.comet.opik.api.FeedbackScoreAverage; +import com.comet.opik.api.FeedbackScoreBatchItem; import com.comet.opik.api.Project; import com.comet.opik.api.ProjectRetrieve; +import com.comet.opik.api.ProjectStats; import com.comet.opik.api.ProjectUpdate; +import com.comet.opik.api.Span; import com.comet.opik.api.Trace; import com.comet.opik.api.TraceUpdate; import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.BigDecimalCollectors; import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; import com.comet.opik.api.resources.utils.ClientSupportUtils; import com.comet.opik.api.resources.utils.MigrationUtils; import com.comet.opik.api.resources.utils.MySQLContainerUtils; import com.comet.opik.api.resources.utils.RedisContainerUtils; +import com.comet.opik.api.resources.utils.StatsUtils; import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; import com.comet.opik.api.resources.utils.TestUtils; import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.api.resources.utils.resources.SpanResourceClient; import com.comet.opik.api.resources.utils.resources.TraceResourceClient; import com.comet.opik.api.sorting.Direction; import com.comet.opik.api.sorting.SortableFields; import com.comet.opik.api.sorting.SortingFactory; import com.comet.opik.api.sorting.SortingField; -import com.comet.opik.domain.ProjectService; import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; import com.comet.opik.podam.PodamFactoryUtils; +import com.comet.opik.utils.DurationUtils; import com.comet.opik.utils.JsonUtils; +import com.comet.opik.utils.ValidationUtils; import com.github.tomakehurst.wiremock.client.WireMock; import com.redis.testcontainers.RedisContainer; import jakarta.ws.rs.HttpMethod; import jakarta.ws.rs.client.Entity; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; -import org.apache.hc.core5.http.HttpStatus; +import org.apache.http.HttpStatus; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -54,17 +63,22 @@ import ru.vyarus.dropwizard.guice.test.ClientSupport; import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; import uk.co.jemos.podam.api.PodamFactory; +import uk.co.jemos.podam.api.PodamUtils; +import java.math.BigDecimal; +import java.math.RoundingMode; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.sql.SQLException; import java.time.Instant; +import java.util.Comparator; import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.UUID; import java.util.regex.Pattern; -import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -80,6 +94,9 @@ import static com.github.tomakehurst.wiremock.client.WireMock.okJson; import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static java.util.stream.Collectors.averagingDouble; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.toMap; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.params.provider.Arguments.arguments; @@ -91,7 +108,7 @@ class ProjectsResourceTest { public static final String URL_TEMPLATE = "%s/v1/private/projects"; public static final String URL_TEMPLATE_TRACE = "%s/v1/private/traces"; public static final String[] IGNORED_FIELDS = {"createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", - "lastUpdatedTraceAt"}; + "lastUpdatedTraceAt", "feedbackScores", "duration", "totalEstimatedCost", "usage"}; private static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); @@ -123,11 +140,11 @@ class ProjectsResourceTest { private String baseURI; private ClientSupport client; - private ProjectService projectService; private TraceResourceClient traceResourceClient; + private SpanResourceClient spanResourceClient; @BeforeAll - void setUpAll(ClientSupport client, Jdbi jdbi, ProjectService projectService) throws SQLException { + void setUpAll(ClientSupport client, Jdbi jdbi) throws SQLException { MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters()); @@ -138,13 +155,13 @@ void setUpAll(ClientSupport client, Jdbi jdbi, ProjectService projectService) th this.baseURI = "http://localhost:%d".formatted(client.getPort()); this.client = client; - this.projectService = projectService; ClientSupportUtils.config(client); mockTargetWorkspace(API_KEY, TEST_WORKSPACE, WORKSPACE_ID); this.traceResourceClient = new TraceResourceClient(this.client, baseURI); + this.spanResourceClient = new SpanResourceClient(this.client, baseURI); } private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { @@ -1090,10 +1107,202 @@ void getProjects__whenProjectsHasTraces__thenReturnProjectWithLastUpdatedTraceAt assertThat(actualEntity.content().get(2).lastUpdatedTraceAt()) .isEqualTo(expectedProject.lastUpdatedTraceAt()); - assertAllProjectsHavePersistedLastTraceAt(workspaceId, List.of(expectedProject, expectedProject2, + assertAllProjectsHavePersistedLastTraceAt(workspaceName, apiKey, List.of(expectedProject, expectedProject2, expectedProject3)); } + @Test + @DisplayName("when projects with traces, spans, feedback scores, and usage, then return project aggregations") + void getProjects__whenProjectsHasTracesSpansFeedbackScoresAndUsage__thenReturnProjectAggregations() { + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class) + .parallelStream() + .map(project -> project.toBuilder() + .id(createProject(project, apiKey, workspaceName)) + .totalEstimatedCost(null) + .usage(null) + .feedbackScores(null) + .duration(null) + .build()) + .toList(); + + List expectedProjects = projects.parallelStream() + .map(project -> buildProjectStats(project, apiKey, workspaceName)) + .sorted(Comparator.comparing(Project::id).reversed()) + .toList(); + + var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get(); + + var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK); + + assertThat(expectedProjects).hasSameSizeAs(actualEntity.content()); + + assertThat(actualEntity.content()) + .usingRecursiveComparison() + .ignoringFields("createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", "lastUpdatedTraceAt") + .ignoringCollectionOrder() + .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) + .isEqualTo(expectedProjects); + } + + @Test + @DisplayName("when projects without traces, spans, feedback scores, and usage, then return project aggregations") + void getProjects__whenProjectsHasNoTracesSpansFeedbackScoresAndUsage__thenReturnProjectAggregations() { + String workspaceName = UUID.randomUUID().toString(); + String apiKey = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var projects = PodamFactoryUtils.manufacturePojoList(factory, Project.class) + .parallelStream() + .map(project -> project.toBuilder() + .id(createProject(project, apiKey, workspaceName)) + .totalEstimatedCost(null) + .usage(null) + .feedbackScores(null) + .duration(null) + .build()) + .toList(); + + List expectedProjects = projects.parallelStream() + .map(project -> project.toBuilder() + .duration(null) + .totalEstimatedCost(null) + .usage(null) + .feedbackScores(null) + .build()) + .sorted(Comparator.comparing(Project::id).reversed()) + .toList(); + + var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get(); + + var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(org.apache.http.HttpStatus.SC_OK); + + assertThat(expectedProjects).hasSameSizeAs(actualEntity.content()); + + assertThat(actualEntity.content()) + .usingRecursiveComparison() + .ignoringFields("createdBy", "lastUpdatedBy", "createdAt", "lastUpdatedAt", "lastUpdatedTraceAt") + .ignoringCollectionOrder() + .withComparatorForType(StatsUtils::bigDecimalComparator, BigDecimal.class) + .isEqualTo(expectedProjects); + } + + private Project buildProjectStats(Project project, String apiKey, String workspaceName) { + var traces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream() + .map(trace -> { + Instant startTime = Instant.now(); + Instant endTime = startTime.plusMillis(PodamUtils.getIntegerInRange(1, 1000)); + return trace.toBuilder() + .projectName(project.name()) + .startTime(startTime) + .endTime(endTime) + .duration(DurationUtils.getDurationInMillisWithSubMilliPrecision(startTime, endTime)) + .build(); + }) + .toList(); + + traceResourceClient.batchCreateTraces(traces, apiKey, workspaceName); + + List scores = PodamFactoryUtils.manufacturePojoList(factory, + FeedbackScoreBatchItem.class); + + traces = traces.stream().map(trace -> { + List spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream() + .map(span -> span.toBuilder() + .usage(spanResourceClient.getTokenUsage()) + .model(spanResourceClient.randomModelPrice().getName()) + .traceId(trace.id()) + .projectName(trace.projectName()) + .build()) + .toList(); + + spanResourceClient.batchCreateSpans(spans, apiKey, workspaceName); + + List feedbackScores = scores.stream() + .map(feedbackScore -> feedbackScore.toBuilder() + .projectId(project.id()) + .projectName(project.name()) + .id(trace.id()) + .build()) + .toList(); + + traceResourceClient.feedbackScores(feedbackScores, apiKey, workspaceName); + + return trace.toBuilder() + .feedbackScores( + feedbackScores.stream() + .map(score -> FeedbackScore.builder() + .value(score.value()) + .name(score.name()) + .build()) + .toList()) + .usage(traceResourceClient.aggregateSpansUsage(spans)) + .totalEstimatedCost(traceResourceClient.aggregateSpansCost(spans)) + .build(); + }).toList(); + + List durations = StatsUtils.calculateQuantiles( + traces.stream() + .map(Trace::duration) + .toList(), + List.of(0.5, 0.90, 0.99)); + + return project.toBuilder() + .duration(new ProjectStats.PercentageValues(durations.get(0), durations.get(1), durations.get(2))) + .totalEstimatedCost(getTotalEstimatedCost(traces)) + .usage(traces.stream() + .map(Trace::usage) + .flatMap(usage -> usage.entrySet().stream()) + .collect(groupingBy(Map.Entry::getKey, averagingDouble(Map.Entry::getValue)))) + .feedbackScores(getScoreAverages(traces)) + .build(); + } + + private List getScoreAverages(List traces) { + return traces.stream() + .map(Trace::feedbackScores) + .flatMap(List::stream) + .collect(groupingBy(FeedbackScore::name, + BigDecimalCollectors.averagingBigDecimal(FeedbackScore::value))) + .entrySet() + .stream() + .map(entry -> FeedbackScoreAverage.builder() + .name(entry.getKey()) + .value(entry.getValue()) + .build()) + .toList(); + } + + private double getTotalEstimatedCost(List traces) { + long count = traces.stream() + .map(Trace::totalEstimatedCost) + .filter(Objects::nonNull) + .filter(cost -> cost.compareTo(BigDecimal.ZERO) > 0) + .count(); + + return traces.stream() + .map(Trace::totalEstimatedCost) + .reduce(BigDecimal.ZERO, BigDecimal::add) + .divide(BigDecimal.valueOf(count), ValidationUtils.SCALE, RoundingMode.HALF_UP).doubleValue(); + } + @Test @DisplayName("when projects is with traces created in batch, then return project with last updated trace at") void getProjects__whenProjectsHasTracesBatch__thenReturnProjectWithLastUpdatedTraceAt() { @@ -1160,7 +1369,7 @@ void getProjects__whenProjectsHasTracesBatch__thenReturnProjectWithLastUpdatedTr assertThat(actualEntity.content().get(2).lastUpdatedTraceAt()) .isEqualTo(expectedProject.lastUpdatedTraceAt()); - assertAllProjectsHavePersistedLastTraceAt(workspaceId, List.of(expectedProject, expectedProject2, + assertAllProjectsHavePersistedLastTraceAt(workspaceName, apiKey, List.of(expectedProject, expectedProject2, expectedProject3)); } @@ -1190,21 +1399,35 @@ void getProjects__whenTraceIsUpdated__thenUpdateProjectsLastUpdatedTraceAt() { Project expectedProject = project.toBuilder().id(projectId).lastUpdatedTraceAt(trace.lastUpdatedAt()) .build(); - assertAllProjectsHavePersistedLastTraceAt(workspaceId, List.of(expectedProject)); + assertAllProjectsHavePersistedLastTraceAt(workspaceName, apiKey, List.of(expectedProject)); } - private void assertAllProjectsHavePersistedLastTraceAt(String workspaceId, List expectedProjects) { - List dbProjects = projectService.findByIds(workspaceId, expectedProjects.stream() - .map(Project::id).collect(Collectors.toUnmodifiableSet())); - - for (Project project : expectedProjects) { - Awaitility.await().untilAsserted(() -> { - assertThat(dbProjects.stream().filter(dbProject -> dbProject.id().equals(project.id())) - .findFirst().orElseThrow().lastUpdatedTraceAt()) - .usingComparator(TestComparators::compareMicroNanoTime) - .isEqualTo(project.lastUpdatedTraceAt()); - }); - } + private void assertAllProjectsHavePersistedLastTraceAt(String workspaceName, String apiKey, + List expectedProjects) { + + Awaitility.await().untilAsserted(() -> { + var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .queryParam("size", 100) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get(); + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + var actualEntity = actualResponse.readEntity(Project.ProjectPage.class); + + assertThat(actualEntity.size()).isEqualTo(expectedProjects.size()); + + Map actualProjectsByLastTraceAt = actualEntity.content() + .stream() + .collect(toMap(Project::id, Project::lastUpdatedTraceAt)); + + assertThat(actualProjectsByLastTraceAt) + .usingRecursiveComparison() + .withComparatorForType(TestComparators::compareMicroNanoTime, Instant.class) + .isEqualTo(expectedProjects.stream() + .collect(toMap(Project::id, Project::lastUpdatedTraceAt))); + }); } } @@ -1274,17 +1497,8 @@ private UUID createCreateTrace(String projectName, String apiKey, String workspa .projectName(projectName) .build(); - try (var actualResponse = client.target(URL_TEMPLATE_TRACE.formatted(baseURI)) - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .post(Entity.json(trace))) { - - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - assertThat(actualResponse.hasEntity()).isFalse(); - - return TestUtils.getIdFromLocation(actualResponse.getLocation()); - } + traceResourceClient.batchCreateTraces(List.of(trace), apiKey, workspaceName); + return trace.id(); } private Trace getTrace(UUID id, String apiKey, String workspaceName) {