Skip to content

Commit

Permalink
[OPIK-287] Add project level aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagohora committed Dec 14, 2024
1 parent 3faf989 commit fb94fbd
Show file tree
Hide file tree
Showing 7 changed files with 472 additions and 57 deletions.
13 changes: 12 additions & 1 deletion apps/opik-backend/src/main/java/com/comet/opik/api/Project.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@

import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.UUID;

import static com.comet.opik.api.ProjectStats.PercentageValues;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
// This annotation is used to specify the strategy to be used for naming of properties for the annotated type. Required so that OpenAPI schema generation uses snake_case
Expand All @@ -29,7 +32,15 @@ public record Project(
@JsonView({Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt,
@JsonView({Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy,
@JsonView({
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Instant lastUpdatedTraceAt){
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Instant lastUpdatedTraceAt,
@JsonView({
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable List<FeedbackScoreAverage> feedbackScores,
@JsonView({
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable PercentageValues duration,
@JsonView({
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Double totalEstimatedCost,
@JsonView({
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Map<String, Double> usage){

public static class View {
public static class Write {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.comet.opik.api.Project.ProjectPage;
import com.comet.opik.api.ProjectCriteria;
import com.comet.opik.api.ProjectIdLastUpdated;
import com.comet.opik.api.ProjectStats;
import com.comet.opik.api.ProjectUpdate;
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.api.error.ErrorMessage;
Expand All @@ -13,6 +14,7 @@
import com.comet.opik.api.sorting.SortingFactoryProjects;
import com.comet.opik.api.sorting.SortingField;
import com.comet.opik.domain.sorting.SortingQueryBuilder;
import com.comet.opik.domain.stats.StatsMapper;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.db.TransactionTemplateAsync;
import com.comet.opik.utils.PaginationUtils;
Expand Down Expand Up @@ -44,6 +46,7 @@
import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY;
import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE;
import static java.util.Collections.reverseOrder;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;
import static java.util.stream.Collectors.toUnmodifiableSet;

Expand Down Expand Up @@ -199,8 +202,19 @@ public Project get(@NonNull UUID id, @NonNull String workspaceId) {
.nonTransaction(connection -> traceDAO.getLastUpdatedTraceAt(Set.of(id), workspaceId, connection))
.block();

Map<UUID, Map<String, Object>> projectStats = getProjectStats(List.of(id), workspaceId);

return enhanceProject(project, lastUpdatedTraceAt, projectStats);
}

/**
 * Returns a copy of {@code project} enriched with its last trace update timestamp and its
 * aggregated stats (feedback scores, duration, total estimated cost and usage).
 *
 * @param project            the project to enrich
 * @param lastUpdatedTraceAt last trace update timestamps keyed by project id
 * @param projectStats       flattened stat name/value maps keyed by project id
 * @return the rebuilt project; a field is set to {@code null} when the maps hold no data for it
 */
private Project enhanceProject(Project project, Map<UUID, Instant> lastUpdatedTraceAt,
        Map<UUID, Map<String, Object>> projectStats) {
    // Looked up once. It is null when the project has no stats entry; the StatsMapper getters
    // used below return null for a null map.
    Map<String, Object> stats = projectStats.get(project.id());

    return project.toBuilder()
            .lastUpdatedTraceAt(lastUpdatedTraceAt.get(project.id()))
            .feedbackScores(StatsMapper.getStatsFeedbackScores(stats))
            .duration(StatsMapper.getStatsDuration(stats))
            .totalEstimatedCost(StatsMapper.getStatsTotalEstimatedCost(stats))
            .usage(StatsMapper.getStatsUsage(stats))
            .build();
}

Expand Down Expand Up @@ -271,17 +285,36 @@ public Page<Project> find(int page, int size, @NonNull ProjectCriteria criteria,
return traceDAO.getLastUpdatedTraceAt(projectIds, workspaceId, connection);
}).block();

List<UUID> projectIds = projectRecordSet.content.stream().map(Project::id).toList();

Map<UUID, Map<String, Object>> projectStats = getProjectStats(projectIds, workspaceId);

List<Project> projects = projectRecordSet.content()
.stream()
.map(project -> project.toBuilder()
.lastUpdatedTraceAt(projectLastUpdatedTraceAtMap.get(project.id()))
.build())
.map(project -> enhanceProject(project, projectLastUpdatedTraceAtMap, projectStats))
.toList();

return new ProjectPage(page, projects.size(), projectRecordSet.total(), projects,
sortingFactory.getSortableFields());
}

/**
 * Loads the trace stats of the given projects and flattens each project's stat items into a
 * stat name to stat value map. Blocks until the DAO call completes.
 *
 * @param projectIds  ids of the projects to load stats for
 * @param workspaceId workspace the projects belong to
 * @return flattened stats keyed by project id
 */
private Map<UUID, Map<String, Object>> getProjectStats(List<UUID> projectIds, String workspaceId) {
    return traceDAO.getStatsByProjectIds(projectIds, workspaceId)
            .map(stats -> stats.entrySet().stream()
                    .collect(toMap(Map.Entry::getKey, entry -> toStatsMap(entry.getValue()))))
            .block();
}

/**
 * Flattens the stat items of one project into a name to value map, leaving out items without a value.
 */
private Map<String, Object> toStatsMap(ProjectStats projectStats) {
    return projectStats.stats()
            .stream()
            // Collectors.toMap throws NullPointerException on null values, and a stat item may carry
            // none (e.g. a duration stat built without percentiles), so such items are skipped.
            .filter(item -> item.getValue() != null)
            .collect(toMap(ProjectStats.ProjectStatItem::getName, ProjectStats.ProjectStatItem::getValue));
}

@Override
public List<Project> findByIds(String workspaceId, Set<UUID> ids) {
if (ids.isEmpty()) {
Expand Down Expand Up @@ -403,6 +436,19 @@ public Project retrieveByName(@NonNull String projectName) {
return repository.findByNames(workspaceId, List.of(projectName))
.stream()
.findFirst()
.map(project -> {

Map<UUID, Instant> projectLastUpdatedTraceAtMap = transactionTemplateAsync
.nonTransaction(connection -> {
Set<UUID> projectIds = Set.of(project.id());
return traceDAO.getLastUpdatedTraceAt(projectIds, workspaceId, connection);
}).block();

Map<UUID, Map<String, Object>> projectStats = getProjectStats(List.of(project.id()),
workspaceId);

return enhanceProject(project, projectLastUpdatedTraceAtMap, projectStats);
})
.orElseThrow(this::createNotFoundError);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ interface TraceDAO {
Mono<ProjectStats> getStats(TraceSearchCriteria criteria);

Mono<Long> getDailyTraces();

Mono<Map<UUID, ProjectStats>> getStatsByProjectIds(List<UUID> projectIds, String workspaceId);
}

@Slf4j
Expand Down Expand Up @@ -609,7 +611,7 @@ LEFT JOIN (
if(length(metadata) > 0, 1, 0) as metadata_count,
length(tags) as tags_count
FROM traces
WHERE project_id = :project_id
WHERE project_id IN :project_ids
AND workspace_id = :workspace_id
<if(filters)> AND <filters> <endif>
<if(feedback_scores_filters)>
Expand All @@ -621,7 +623,7 @@ AND id IN (
FROM feedback_scores
WHERE entity_type = 'trace'
AND workspace_id = :workspace_id
AND project_id = :project_id
AND project_id IN :project_ids
ORDER BY entity_id DESC, last_updated_at DESC
LIMIT 1 BY entity_id, name
)
Expand All @@ -645,7 +647,7 @@ AND id IN (
total_estimated_cost
FROM spans
WHERE workspace_id = :workspace_id
AND project_id = :project_id
AND project_id IN :project_ids
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
)
Expand All @@ -669,7 +671,7 @@ LEFT JOIN (
total_estimated_cost
FROM spans
WHERE workspace_id = :workspace_id
AND project_id = :project_id
AND project_id IN :project_ids
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
)
Expand All @@ -692,7 +694,7 @@ LEFT JOIN (
FROM feedback_scores
WHERE entity_type = 'trace'
AND workspace_id = :workspace_id
AND project_id = :project_id
AND project_id IN :project_ids
ORDER BY entity_id DESC, last_updated_at DESC
LIMIT 1 BY entity_id, name
) GROUP BY project_id, entity_id
Expand Down Expand Up @@ -1125,7 +1127,7 @@ public Mono<ProjectStats> getStats(@NonNull TraceSearchCriteria criteria) {
ST statsSQL = newFindTemplate(SELECT_TRACES_STATS, criteria);

var statement = connection.createStatement(statsSQL.render())
.bind("project_id", criteria.projectId());
.bind("project_ids", List.of(criteria.projectId()).toArray(UUID[]::new));

bindSearchCriteria(criteria, statement);

Expand All @@ -1148,6 +1150,30 @@ public Mono<Long> getDailyTraces() {
.reduce(0L, Long::sum);
}

/**
 * Fetches the trace stats of several projects of one workspace with a single statement.
 *
 * @param projectIds  ids of the projects to aggregate; an empty list short-circuits to an empty map
 * @param workspaceId workspace the projects belong to
 * @return a {@link Mono} emitting the stats keyed by the {@code project_id} column of each result row
 */
@Override
public Mono<Map<UUID, ProjectStats>> getStatsByProjectIds(@NonNull List<UUID> projectIds,
        @NonNull String workspaceId) {

    // Nothing to look up, so no statement is issued.
    if (projectIds.isEmpty()) {
        return Mono.just(Map.of());
    }

    return asyncTemplate
            .nonTransaction(connection -> {
                // The template is rendered without any attributes; only the two binds below are supplied.
                Statement statement = connection.createStatement(new ST(SELECT_TRACES_STATS).render())
                        .bind("project_ids", projectIds.toArray(UUID[]::new))
                        .bind("workspace_id", workspaceId);

                // Each row becomes one (project id, stats) entry; the entries are collected into the result map.
                return Mono.from(statement.execute())
                        .flatMapMany(result -> result.map((row, rowMetadata) -> Map.entry(
                                row.get("project_id", UUID.class),
                                StatsMapper.mapProjectStats(row, "trace_count"))))
                        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
            });
}

@Override
@WithSpan
public Mono<Map<UUID, Instant>> getLastUpdatedTraceAt(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.comet.opik.domain.stats;

import com.comet.opik.api.FeedbackScoreAverage;
import com.comet.opik.api.ProjectStats;
import io.r2dbc.spi.Row;

Expand All @@ -9,48 +10,59 @@
import java.util.Optional;
import java.util.stream.Stream;

import static com.comet.opik.api.ProjectStats.AvgValueStat;
import static com.comet.opik.api.ProjectStats.CountValueStat;
import static com.comet.opik.api.ProjectStats.PercentageValueStat;
import static com.comet.opik.api.ProjectStats.PercentageValues;
import static java.util.stream.Collectors.toMap;

public class StatsMapper {

public static final String USAGE = "usage";
public static final String FEEDBACK_SCORE = "feedback_scores";
public static final String TOTAL_ESTIMATED_COST = "total_estimated_cost";
public static final String DURATION = "duration";

public static ProjectStats mapProjectStats(Row row, String entityCountLabel) {

var stats = Stream.<ProjectStats.ProjectStatItem<?>>builder()
.add(new ProjectStats.CountValueStat(entityCountLabel,
.add(new CountValueStat(entityCountLabel,
row.get(entityCountLabel, Long.class)))
.add(new ProjectStats.PercentageValueStat("duration", Optional
.ofNullable(row.get("duration", List.class))
.map(durations -> new ProjectStats.PercentageValues(
.add(new PercentageValueStat(DURATION, Optional
.ofNullable(row.get(DURATION, List.class))
.map(durations -> new PercentageValues(
getP(durations, 0),
getP(durations, 1),
getP(durations, 2)))
.orElse(null)))
.add(new ProjectStats.CountValueStat("input", row.get("input", Long.class)))
.add(new ProjectStats.CountValueStat("output", row.get("output", Long.class)))
.add(new ProjectStats.CountValueStat("metadata", row.get("metadata", Long.class)))
.add(new ProjectStats.AvgValueStat("tags", row.get("tags", Double.class)));
.add(new CountValueStat("input", row.get("input", Long.class)))
.add(new CountValueStat("output", row.get("output", Long.class)))
.add(new CountValueStat("metadata", row.get("metadata", Long.class)))
.add(new AvgValueStat("tags", row.get("tags", Double.class)));

BigDecimal totalEstimatedCost = row.get("total_estimated_cost_avg", BigDecimal.class);
if (totalEstimatedCost == null) {
totalEstimatedCost = BigDecimal.ZERO;
}

stats.add(new ProjectStats.AvgValueStat("total_estimated_cost", totalEstimatedCost.doubleValue()));
stats.add(new AvgValueStat(TOTAL_ESTIMATED_COST, totalEstimatedCost.doubleValue()));

Map<String, Double> usage = row.get("usage", Map.class);
Map<String, Double> feedbackScores = row.get("feedback_scores", Map.class);
Map<String, Double> usage = row.get(USAGE, Map.class);
Map<String, Double> feedbackScores = row.get(FEEDBACK_SCORE, Map.class);

if (usage != null) {
usage.keySet()
.stream()
.sorted()
.forEach(key -> stats
.add(new ProjectStats.AvgValueStat("%s.%s".formatted("usage", key), usage.get(key))));
.add(new AvgValueStat("%s.%s".formatted(USAGE, key), usage.get(key))));
}

if (feedbackScores != null) {
feedbackScores.keySet()
.stream()
.sorted()
.forEach(key -> stats.add(new ProjectStats.AvgValueStat("%s.%s".formatted("feedback_score", key),
.forEach(key -> stats.add(new AvgValueStat("%s.%s".formatted(FEEDBACK_SCORE, key),
feedbackScores.get(key))));
}

Expand All @@ -61,4 +73,38 @@ private static BigDecimal getP(List<BigDecimal> durations, int index) {
return durations.get(index);
}

/**
 * Extracts the usage stats from a flattened stats map.
 *
 * @param stats stat values keyed by stat name; may be {@code null}
 * @return usage values keyed by stat name with the leading {@code USAGE + "."} prefix removed,
 *         or {@code null} when {@code stats} is {@code null}
 */
public static Map<String, Double> getStatsUsage(Map<String, Object> stats) {
    if (stats == null) {
        return null;
    }

    // Filter on the same dotted prefix that is stripped afterwards. Matching on the bare USAGE
    // constant would also accept keys without the dot and then cut them at the wrong offset
    // (or throw when the key is shorter than the prefix).
    String prefix = "%s.".formatted(USAGE);

    return stats.entrySet()
            .stream()
            .filter(entry -> entry.getKey().startsWith(prefix))
            .collect(toMap(entry -> entry.getKey().substring(prefix.length()),
                    entry -> (Double) entry.getValue()));
}

/**
 * Reads the total estimated cost from a flattened stats map.
 *
 * @param stats stat values keyed by stat name; may be {@code null}
 * @return the value stored under {@code TOTAL_ESTIMATED_COST}, or {@code null} when the map or the entry is absent
 */
public static Double getStatsTotalEstimatedCost(Map<String, ?> stats) {
    if (stats == null) {
        return null;
    }

    Object totalEstimatedCost = stats.get(TOTAL_ESTIMATED_COST);
    return (Double) totalEstimatedCost;
}

/**
 * Extracts the feedback score averages from a flattened stats map.
 *
 * @param stats stat values keyed by stat name; may be {@code null}
 * @return one {@link FeedbackScoreAverage} per key starting with {@code FEEDBACK_SCORE + "."}, named by the
 *         key with that prefix removed and ordered by key, or {@code null} when {@code stats} is {@code null}
 */
public static List<FeedbackScoreAverage> getStatsFeedbackScores(Map<String, ?> stats) {
    if (stats == null) {
        return null;
    }

    // Filter on the same dotted prefix that is stripped afterwards, so a key that merely starts with
    // the bare constant is never cut at the wrong offset.
    String prefix = "%s.".formatted(FEEDBACK_SCORE);

    return stats.keySet()
            .stream()
            .filter(key -> key.startsWith(prefix))
            // Sorted so the result does not depend on the map's iteration order.
            .sorted()
            .map(key -> new FeedbackScoreAverage(key.substring(prefix.length()),
                    BigDecimal.valueOf((Double) stats.get(key))))
            .toList();
}

/**
 * Reads the duration percentiles from a flattened stats map.
 *
 * @param stats stat values keyed by stat name; may be {@code null}
 * @return the value stored under {@code DURATION}, or {@code null} when the map or the entry is absent
 */
public static PercentageValues getStatsDuration(Map<String, ?> stats) {
    if (stats == null) {
        return null;
    }

    Object duration = stats.get(DURATION);
    return (PercentageValues) duration;
}
}
Loading

0 comments on commit fb94fbd

Please sign in to comment.