[OPIK-287] Add project level aggregations #894

Merged — 22 commits, Dec 23, 2024
@@ -11,8 +11,11 @@

import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.UUID;

import static com.comet.opik.api.ProjectStats.PercentageValues;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
// This annotation is used to specify the strategy to be used for naming of properties for the annotated type. Required so that OpenAPI schema generation uses snake_case
@@ -29,7 +32,15 @@ public record Project(
@JsonView({Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt,
@JsonView({Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy,
@JsonView({
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Instant lastUpdatedTraceAt){
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Instant lastUpdatedTraceAt,
@JsonView({
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable List<FeedbackScoreAverage> feedbackScores,
@JsonView({
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable PercentageValues duration,
@JsonView({
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Double totalEstimatedCost,
Collaborator:
Monetary cost as a floating-point type. For other aggregations, such as avg, it wouldn't be that bad. But for a total, which I assume is a sum, BigDecimal would be better.

Contributor:
The name is a bit misleading. This holds the average total cost of a trace in a project. This is calculated here:

avgIf(total_estimated_cost, total_estimated_cost > 0) AS total_estimated_cost_,

In addition, you can see that the test coverage for this logic asserts on an average:

return traces.stream()
.map(Trace::totalEstimatedCost)
.reduce(BigDecimal.ZERO, BigDecimal::add)
.divide(BigDecimal.valueOf(count), ValidationUtils.SCALE, RoundingMode.HALF_UP).doubleValue();

Collaborator:
Can you rename it to something more meaningful and less confusing, then? E.g. avgTotalEstimatedCost or similar.

Collaborator:
These are arguably not clear names in comparison with other aggregations such as lastUpdatedTraceAt. On the other hand, List<FeedbackScoreAverage> feedbackScores already exists in Experiments.

Not a blocker so far.
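The BigDecimal-vs-double point above can be made concrete. The following standalone sketch (sample cost values and the SCALE constant are made up for the demo; they are not the project's actual values) shows why a sum of monetary values drifts as double but stays exact as BigDecimal, and mirrors the sum-then-divide averaging style of the test snippet quoted in the thread:

```java
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;

// Illustrative only: why a *sum* of monetary values is safer as BigDecimal,
// while an average that is ultimately rendered as double loses little.
public class CostAggregationDemo {
    static final int SCALE = 12; // stand-in for ValidationUtils.SCALE

    public static void main(String[] args) {
        List<BigDecimal> costs = List.of(
                new BigDecimal("0.1"), new BigDecimal("0.1"), new BigDecimal("0.1"));

        // Summing as double accumulates binary floating-point error.
        double doubleSum = costs.stream().mapToDouble(BigDecimal::doubleValue).sum();
        System.out.println(doubleSum);  // 0.30000000000000004

        // Summing as BigDecimal is exact.
        BigDecimal exactSum = costs.stream().reduce(BigDecimal.ZERO, BigDecimal::add);
        System.out.println(exactSum);   // 0.3

        // Average in the style of the quoted test: exact sum, then divide.
        double avg = exactSum
                .divide(BigDecimal.valueOf(costs.size()), SCALE, RoundingMode.HALF_UP)
                .doubleValue();
        System.out.println(avg);        // 0.1
    }
}
```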

@JsonView({
Project.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Map<String, Double> usage){

public static class View {
public static class Write {
@@ -20,7 +20,6 @@
@ImplementedBy(ProjectMetricsServiceImpl.class)
public interface ProjectMetricsService {
String ERR_START_BEFORE_END = "'start_time' must be before 'end_time'";
String ERR_PROJECT_METRIC_NOT_SUPPORTED = "metric '%s' is not supported";

Mono<ProjectMetricResponse<Number>> getProjectMetrics(UUID projectId, ProjectMetricRequest request);
}
@@ -5,6 +5,7 @@
import com.comet.opik.api.Project.ProjectPage;
import com.comet.opik.api.ProjectCriteria;
import com.comet.opik.api.ProjectIdLastUpdated;
import com.comet.opik.api.ProjectStats;
import com.comet.opik.api.ProjectUpdate;
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.api.error.ErrorMessage;
@@ -13,6 +14,7 @@
import com.comet.opik.api.sorting.SortingFactoryProjects;
import com.comet.opik.api.sorting.SortingField;
import com.comet.opik.domain.sorting.SortingQueryBuilder;
import com.comet.opik.domain.stats.StatsMapper;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.db.TransactionTemplateAsync;
import com.comet.opik.utils.PaginationUtils;
@@ -44,6 +46,7 @@
import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY;
import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE;
import static java.util.Collections.reverseOrder;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;
import static java.util.stream.Collectors.toUnmodifiableSet;

@@ -199,8 +202,18 @@ public Project get(@NonNull UUID id, @NonNull String workspaceId) {
.nonTransaction(connection -> traceDAO.getLastUpdatedTraceAt(Set.of(id), workspaceId, connection))
.block();

Map<UUID, Map<String, Object>> projectStats = getProjectStats(List.of(id), workspaceId);

return enhanceProject(project, lastUpdatedTraceAt.get(project.id()), projectStats.get(project.id()));
}

private Project enhanceProject(Project project, Instant lastUpdatedTraceAt, Map<String, Object> projectStats) {
return project.toBuilder()
.lastUpdatedTraceAt(lastUpdatedTraceAt.get(id))
.lastUpdatedTraceAt(lastUpdatedTraceAt)
.feedbackScores(StatsMapper.getStatsFeedbackScores(projectStats))
.duration(StatsMapper.getStatsDuration(projectStats))
.totalEstimatedCost(StatsMapper.getStatsTotalEstimatedCost(projectStats))
.usage(StatsMapper.getStatsUsage(projectStats))
.build();
}
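The enhanceProject helper above relies on projects being immutable: aggregations are attached by copying the record and overriding only the stats fields. A minimal plain-Java sketch of that pattern (Lombok's toBuilder() generates the equivalent copy for the real Project; the hand-written "with" method below is a hypothetical stand-in):

```java
// Sketch of copy-with-override on an immutable record, assuming a simplified
// two-field Project; the real record carries many more fields.
public class EnhanceProjectSketch {
    record Project(String name, Double totalEstimatedCost) {
        Project withTotalEstimatedCost(Double cost) {
            // Copy every field, replacing only the aggregation being attached.
            return new Project(name, cost);
        }
    }

    public static void main(String[] args) {
        Project bare = new Project("demo", null);
        Project enhanced = bare.withTotalEstimatedCost(0.42);
        System.out.println(enhanced.totalEstimatedCost()); // 0.42
        System.out.println(bare.totalEstimatedCost());     // null (original untouched)
    }
}
```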

@@ -271,17 +284,37 @@ public Page<Project> find(int page, int size, @NonNull ProjectCriteria criteria,
return traceDAO.getLastUpdatedTraceAt(projectIds, workspaceId, connection);
}).block();

List<UUID> projectIds = projectRecordSet.content.stream().map(Project::id).toList();

Map<UUID, Map<String, Object>> projectStats = getProjectStats(projectIds, workspaceId);

List<Project> projects = projectRecordSet.content()
.stream()
.map(project -> project.toBuilder()
.lastUpdatedTraceAt(projectLastUpdatedTraceAtMap.get(project.id()))
.build())
.map(project -> enhanceProject(project, projectLastUpdatedTraceAtMap.get(project.id()),
projectStats.get(project.id())))
.toList();

return new ProjectPage(page, projects.size(), projectRecordSet.total(), projects,
sortingFactory.getSortableFields());
}

private Map<UUID, Map<String, Object>> getProjectStats(List<UUID> projectIds, String workspaceId) {
return traceDAO.getStatsByProjectIds(projectIds, workspaceId)
.map(stats -> stats.entrySet().stream()
.map(entry -> {
Map<String, Object> statsMap = entry.getValue()
.stats()
.stream()
.collect(toMap(ProjectStats.ProjectStatItem::getName,
ProjectStats.ProjectStatItem::getValue));

return Map.entry(entry.getKey(), statsMap);
})
.collect(toMap(Map.Entry::getKey, Map.Entry::getValue)))
.block();
}
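The pivot performed by getProjectStats can be sketched in isolation: a per-project list of named stat items becomes a Map<String, Object> keyed by stat name. In this self-contained version the StatItem record and string project keys are simplified stand-ins for ProjectStats.ProjectStatItem and UUID keys:

```java
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.toMap;

// Simplified sketch of the stats pivot: for each project, collapse its list
// of (name, value) stat items into a name-keyed map.
public class StatsPivotSketch {
    record StatItem(String name, Object value) {}

    static Map<String, Map<String, Object>> pivot(Map<String, List<StatItem>> statsByProject) {
        return statsByProject.entrySet().stream()
                .collect(toMap(
                        Map.Entry::getKey,
                        e -> e.getValue().stream()
                                .collect(toMap(StatItem::name, StatItem::value))));
    }

    public static void main(String[] args) {
        var input = Map.of("project-1", List.of(
                new StatItem("trace_count", 42L),
                new StatItem("total_estimated_cost", 0.07)));
        System.out.println(pivot(input).get("project-1").get("trace_count")); // 42
    }
}
```

Note that Collectors.toMap throws on duplicate keys, which is the desired behavior here: stat names are unique within a project.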

@Override
public List<Project> findByIds(String workspaceId, Set<UUID> ids) {
if (ids.isEmpty()) {
@@ -328,10 +361,12 @@ private Page<Project> findWithLastTraceSorting(int page, int size, @NonNull Proj
return repository.findByIds(new HashSet<>(finalIds), workspaceId);
}).stream().collect(Collectors.toMap(Project::id, Function.identity()));

Map<UUID, Map<String, Object>> projectStats = getProjectStats(finalIds, workspaceId);

// compose the final projects list by the correct order and add last trace to it
List<Project> projects = finalIds.stream().map(id -> projectsById.get(id).toBuilder()
.lastUpdatedTraceAt(projectLastUpdatedTraceAtMap.get(id))
.build())
List<Project> projects = finalIds.stream().map(projectsById::get)
.map(project -> enhanceProject(project, projectLastUpdatedTraceAtMap.get(project.id()),
projectStats.get(project.id())))
.toList();

return new ProjectPage(page, projects.size(), allProjectIdsLastUpdated.size(), projects,
@@ -403,6 +438,20 @@ public Project retrieveByName(@NonNull String projectName) {
return repository.findByNames(workspaceId, List.of(projectName))
.stream()
.findFirst()
.map(project -> {

Map<UUID, Instant> projectLastUpdatedTraceAtMap = transactionTemplateAsync
.nonTransaction(connection -> {
Set<UUID> projectIds = Set.of(project.id());
return traceDAO.getLastUpdatedTraceAt(projectIds, workspaceId, connection);
}).block();

Map<UUID, Map<String, Object>> projectStats = getProjectStats(List.of(project.id()),
workspaceId);

return enhanceProject(project, projectLastUpdatedTraceAtMap.get(project.id()),
projectStats.get(project.id()));
})
.orElseThrow(this::createNotFoundError);
});
}
@@ -86,6 +86,8 @@ interface TraceDAO {
Mono<ProjectStats> getStats(TraceSearchCriteria criteria);

Mono<Long> getDailyTraces();

Mono<Map<UUID, ProjectStats>> getStatsByProjectIds(List<UUID> projectIds, String workspaceId);
}

@Slf4j
@@ -639,7 +641,7 @@ AND notEquals(start_time, toDateTime64('1970-01-01 00:00:00.000', 9)),
if(length(metadata) > 0, 1, 0) as metadata_count,
length(tags) as tags_count
FROM traces
WHERE project_id = :project_id
WHERE project_id IN :project_ids
AND workspace_id = :workspace_id
<if(filters)> AND <filters> <endif>
<if(feedback_scores_filters)>
@@ -651,7 +653,7 @@ AND id IN (
FROM feedback_scores
WHERE entity_type = 'trace'
AND workspace_id = :workspace_id
AND project_id = :project_id
AND project_id IN :project_ids
ORDER BY entity_id DESC, last_updated_at DESC
LIMIT 1 BY entity_id, name
)
@@ -675,7 +677,7 @@ AND id IN (
total_estimated_cost
FROM spans
WHERE workspace_id = :workspace_id
AND project_id = :project_id
AND project_id IN :project_ids
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
)
@@ -699,7 +701,7 @@ LEFT JOIN (
total_estimated_cost
FROM spans
WHERE workspace_id = :workspace_id
AND project_id = :project_id
AND project_id IN :project_ids
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
)
@@ -722,7 +724,7 @@ LEFT JOIN (
FROM feedback_scores
WHERE entity_type = 'trace'
AND workspace_id = :workspace_id
AND project_id = :project_id
AND project_id IN :project_ids
ORDER BY entity_id DESC, last_updated_at DESC
LIMIT 1 BY entity_id, name
) GROUP BY project_id, entity_id
@@ -1174,7 +1176,7 @@ public Mono<ProjectStats> getStats(@NonNull TraceSearchCriteria criteria) {
ST statsSQL = newFindTemplate(SELECT_TRACES_STATS, criteria);

var statement = connection.createStatement(statsSQL.render())
.bind("project_id", criteria.projectId());
.bind("project_ids", List.of(criteria.projectId()));

bindSearchCriteria(criteria, statement);

@@ -1197,6 +1199,30 @@ public Mono<Long> getDailyTraces() {
.reduce(0L, Long::sum);
}

@Override
public Mono<Map<UUID, ProjectStats>> getStatsByProjectIds(@NonNull List<UUID> projectIds,
@NonNull String workspaceId) {

if (projectIds.isEmpty()) {
return Mono.just(Map.of());
}

return asyncTemplate
.nonTransaction(connection -> {
Statement statement = connection.createStatement(new ST(SELECT_TRACES_STATS).render())
.bind("project_ids", projectIds)
.bind("workspace_id", workspaceId);

return Mono.from(statement.execute())
.flatMapMany(result -> result.map((row, rowMetadata) -> Map.of(
row.get("project_id", UUID.class),
StatsMapper.mapProjectStats(row, "trace_count"))))
.map(Map::entrySet)
.flatMap(Flux::fromIterable)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
});
}
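Setting the reactive machinery aside, getStatsByProjectIds boils down to grouping result rows into a map keyed by project id, with an early return of an empty map when no ids are requested. A blocking, plain-streams analogue (Row and Stats below are hypothetical stand-ins for the r2dbc Row and ProjectStats types):

```java
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;

// Plain-streams sketch of the per-project grouping: each "row" carries a
// project id plus its stats; rows collect into a project-id-keyed map.
public class StatsByProjectSketch {
    record Stats(long traceCount) {}
    record Row(UUID projectId, Stats stats) {}

    static Map<UUID, Stats> byProjectId(List<Row> rows) {
        if (rows.isEmpty()) {
            return Map.of(); // mirrors the empty-projectIds short-circuit
        }
        return rows.stream()
                .collect(Collectors.toMap(Row::projectId, Row::stats));
    }

    public static void main(String[] args) {
        UUID p1 = UUID.randomUUID();
        Map<UUID, Stats> result = byProjectId(List.of(new Row(p1, new Stats(3))));
        System.out.println(result.get(p1).traceCount()); // 3
    }
}
```

The real implementation does the same thing non-blockingly, flattening the per-row maps through Flux and collecting their entries with Collectors.toMap.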

@Override
@WithSpan
public Mono<Map<UUID, Instant>> getLastUpdatedTraceAt(
@@ -1,5 +1,6 @@
package com.comet.opik.domain.stats;

import com.comet.opik.api.FeedbackScoreAverage;
import com.comet.opik.api.ProjectStats;
import io.r2dbc.spi.Row;

@@ -9,48 +10,59 @@
import java.util.Optional;
import java.util.stream.Stream;

import static com.comet.opik.api.ProjectStats.AvgValueStat;
import static com.comet.opik.api.ProjectStats.CountValueStat;
import static com.comet.opik.api.ProjectStats.PercentageValueStat;
import static com.comet.opik.api.ProjectStats.PercentageValues;
import static java.util.stream.Collectors.toMap;

public class StatsMapper {

public static final String USAGE = "usage";
public static final String FEEDBACK_SCORE = "feedback_scores";
public static final String TOTAL_ESTIMATED_COST = "total_estimated_cost";
public static final String DURATION = "duration";

public static ProjectStats mapProjectStats(Row row, String entityCountLabel) {

var stats = Stream.<ProjectStats.ProjectStatItem<?>>builder()
.add(new ProjectStats.CountValueStat(entityCountLabel,
.add(new CountValueStat(entityCountLabel,
row.get(entityCountLabel, Long.class)))
.add(new ProjectStats.PercentageValueStat("duration", Optional
.ofNullable(row.get("duration", List.class))
.map(durations -> new ProjectStats.PercentageValues(
.add(new PercentageValueStat(DURATION, Optional
.ofNullable(row.get(DURATION, List.class))
.map(durations -> new PercentageValues(
getP(durations, 0),
getP(durations, 1),
getP(durations, 2)))
.orElse(null)))
.add(new ProjectStats.CountValueStat("input", row.get("input", Long.class)))
.add(new ProjectStats.CountValueStat("output", row.get("output", Long.class)))
.add(new ProjectStats.CountValueStat("metadata", row.get("metadata", Long.class)))
.add(new ProjectStats.AvgValueStat("tags", row.get("tags", Double.class)));
.add(new CountValueStat("input", row.get("input", Long.class)))
.add(new CountValueStat("output", row.get("output", Long.class)))
.add(new CountValueStat("metadata", row.get("metadata", Long.class)))
.add(new AvgValueStat("tags", row.get("tags", Double.class)));

BigDecimal totalEstimatedCost = row.get("total_estimated_cost_avg", BigDecimal.class);
if (totalEstimatedCost == null) {
totalEstimatedCost = BigDecimal.ZERO;
}

stats.add(new ProjectStats.AvgValueStat("total_estimated_cost", totalEstimatedCost.doubleValue()));
stats.add(new AvgValueStat(TOTAL_ESTIMATED_COST, totalEstimatedCost.doubleValue()));

Map<String, Double> usage = row.get("usage", Map.class);
Map<String, Double> feedbackScores = row.get("feedback_scores", Map.class);
Map<String, Double> usage = row.get(USAGE, Map.class);
Map<String, Double> feedbackScores = row.get(FEEDBACK_SCORE, Map.class);

if (usage != null) {
usage.keySet()
.stream()
.sorted()
.forEach(key -> stats
.add(new ProjectStats.AvgValueStat("%s.%s".formatted("usage", key), usage.get(key))));
.add(new AvgValueStat("%s.%s".formatted(USAGE, key), usage.get(key))));
}

if (feedbackScores != null) {
feedbackScores.keySet()
.stream()
.sorted()
.forEach(key -> stats.add(new ProjectStats.AvgValueStat("%s.%s".formatted("feedback_score", key),
.forEach(key -> stats.add(new AvgValueStat("%s.%s".formatted(FEEDBACK_SCORE, key),
feedbackScores.get(key))));
}

Expand All @@ -61,4 +73,38 @@ private static BigDecimal getP(List<BigDecimal> durations, int index) {
return durations.get(index);
}

public static Map<String, Double> getStatsUsage(Map<String, Object> stats) {
return Optional.ofNullable(stats)
.map(map -> map.keySet()
.stream()
.filter(k -> k.startsWith(USAGE))
.map(
k -> Map.entry(k.substring("%s.".formatted(USAGE).length()), (Double) map.get(k)))
.collect(toMap(Map.Entry::getKey, Map.Entry::getValue)))
.orElse(null);
}

public static Double getStatsTotalEstimatedCost(Map<String, ?> stats) {
return Optional.ofNullable(stats)
.map(map -> map.get(TOTAL_ESTIMATED_COST))
.map(v -> (Double) v)
.orElse(null);
}

public static List<FeedbackScoreAverage> getStatsFeedbackScores(Map<String, ?> stats) {
return Optional.ofNullable(stats)
.map(map -> map.keySet()
.stream()
.filter(k -> k.startsWith(FEEDBACK_SCORE))
.map(k -> new FeedbackScoreAverage(k.substring("%s.".formatted(FEEDBACK_SCORE).length()),
BigDecimal.valueOf((Double) map.get(k))))
.toList())
.orElse(null);
}

public static PercentageValues getStatsDuration(Map<String, ?> stats) {
return Optional.ofNullable(stats)
.map(map -> (PercentageValues) map.get(DURATION))
.orElse(null);
}
}
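The prefix-based helpers above recover nested structure from the flat stats map, whose keys look like "usage.completion_tokens". A self-contained usage sketch (the flat map is hand-written sample data, and the filter here uses the prefix plus a dot, a slight tightening of the startsWith check in the PR):

```java
import java.util.Map;
import java.util.Optional;
import static java.util.stream.Collectors.toMap;

// Sketch of getStatsUsage: keep only "usage."-prefixed keys, strip the
// prefix, and return null when no stats are available.
public class StatsUsageSketch {
    static final String USAGE = "usage";

    static Map<String, Double> getStatsUsage(Map<String, Object> stats) {
        return Optional.ofNullable(stats)
                .map(map -> map.keySet().stream()
                        .filter(k -> k.startsWith(USAGE + "."))
                        .collect(toMap(
                                k -> k.substring(USAGE.length() + 1),
                                k -> (Double) map.get(k))))
                .orElse(null);
    }

    public static void main(String[] args) {
        Map<String, Object> flat = Map.of(
                "trace_count", 10L,
                "usage.completion_tokens", 12.5,
                "usage.prompt_tokens", 30.0);
        Map<String, Double> usage = getStatsUsage(flat);
        System.out.println(usage.get("completion_tokens")); // 12.5
        System.out.println(usage.size());                   // 2
    }
}
```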