Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPIK-446] Add duration metric #889

Merged
merged 12 commits into from
Dec 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Optional;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Stream;

import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToMono;
import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.endSegment;
Expand All @@ -35,11 +36,16 @@
public interface ProjectMetricsDAO {
// Series names carried in Entry.name.
// NOTE(review): NAME_TRACES and NAME_COST are not used in the visible part of this file;
// presumably they label the trace-count and cost series - confirm against the implementation.
String NAME_TRACES = "traces";
String NAME_COST = "cost";
// Duration percentile series. mapDuration assigns these names to elements 0, 1 and 2 of the
// "duration" array, which GET_TRACE_DURATION fills with quantiles(0.5, 0.9, 0.99).
String NAME_DURATION_P50 = "duration.p50";
String NAME_DURATION_P90 = "duration.p90";
String NAME_DURATION_P99 = "duration.p99";

/**
 * One data point of a named metric series.
 *
 * @param name  name of the series the point belongs to (for durations, one of the
 *              {@code NAME_DURATION_*} constants)
 * @param time  timestamp of the time bucket the value was computed for
 * @param value value for that bucket; may be {@code null} when no value is available
 */
@Builder
record Entry(String name, Instant time, Number value) {}

/**
 * Fetches the trace duration percentile series of a project.
 *
 * <p>Both parameters are annotated {@code @NonNull}, matching the other metric methods of
 * this interface and the implementation's own signature.
 *
 * @param projectId project whose traces are aggregated
 * @param request   metric request the query is built from
 * @return the entries named {@code NAME_DURATION_P50}, {@code NAME_DURATION_P90} and
 *         {@code NAME_DURATION_P99}
 */
Mono<List<Entry>> getDuration(@NonNull UUID projectId, @NonNull ProjectMetricRequest request);
Mono<List<Entry>> getTraceCount(@NonNull UUID projectId, @NonNull ProjectMetricRequest request);
Mono<List<Entry>> getFeedbackScores(@NonNull UUID projectId, @NonNull ProjectMetricRequest request);
Mono<List<Entry>> getTokenUsage(@NonNull UUID projectId, @NonNull ProjectMetricRequest request);
Expand All @@ -57,6 +63,28 @@ class ProjectMetricsDAOImpl implements ProjectMetricsDAO {
TimeInterval.DAILY, "toIntervalDay(1)",
TimeInterval.HOURLY, "toIntervalHour(1)");

/**
 * Query template returning, per time bucket, the 0.5, 0.9 and 0.99 quantiles of trace duration
 * as the array column {@code duration}.
 *
 * <ul>
 *   <li>A trace's duration is {@code dateDiff('microsecond', start_time, end_time) / 1000.0},
 *       i.e. milliseconds. Traces whose {@code start_time} or {@code end_time} is NULL, or whose
 *       {@code start_time} equals 1970-01-01 00:00:00.000, contribute NULL instead.</li>
 *   <li>Each quantile is mapped through {@code toDecimal64(..., 9)}; a NaN quantile is replaced
 *       by 0 first.</li>
 *   <li>Rows are restricted to {@code :project_id}, {@code :workspace_id} and a
 *       {@code start_time} between {@code :start_time} and {@code :end_time} (inclusive).</li>
 *   <li>The result is ordered by bucket and gap-filled ({@code WITH FILL}) up to
 *       {@code :end_time}.</li>
 * </ul>
 *
 * <p>NOTE(review): {@code <bucket>}, {@code <fill_from>} and {@code <step>} look like template
 * placeholders substituted before execution, and the escaped {@code <=} presumably keeps the
 * template engine from reading it as a placeholder delimiter - confirm against getMetric.
 */
private static final String GET_TRACE_DURATION = """
SELECT <bucket> AS bucket,
arrayMap(v ->
toDecimal64(if(isNaN(v), 0, v), 9),
quantiles(0.5, 0.9, 0.99)(if(end_time IS NOT NULL AND start_time IS NOT NULL
AND notEquals(start_time, toDateTime64('1970-01-01 00:00:00.000', 9)),
(dateDiff('microsecond', start_time, end_time) / 1000.0),
NULL))
) AS duration
FROM traces
WHERE project_id = :project_id
AND workspace_id = :workspace_id
AND start_time >= parseDateTime64BestEffort(:start_time, 9)
AND start_time \\<= parseDateTime64BestEffort(:end_time, 9)
GROUP BY bucket
ORDER BY bucket
WITH FILL
FROM <fill_from>
TO parseDateTimeBestEffort(:end_time)
STEP <step>;
""";

private static final String GET_TRACE_COUNT = """
SELECT <bucket> AS bucket,
nullIf(count(DISTINCT id), 0) as count
Expand Down Expand Up @@ -151,6 +179,41 @@ TO parseDateTimeBestEffort(:end_time)
STEP <step>;
""";

/**
 * Runs {@code GET_TRACE_DURATION} for the project and flattens every result row into its
 * percentile entries (see {@code mapDuration}), preserving row order.
 *
 * <p>Entries are gathered with {@code collectList()} rather than by folding the per-row
 * streams with {@code Stream.concat}: a concatenation chain as deep as the number of rows can
 * overflow the stack when traversed. When the query yields no rows, the returned {@code Mono}
 * now emits an empty list instead of completing without a value.
 *
 * @param projectId project whose traces are aggregated
 * @param request   metric request the query is built from
 * @return all duration entries of the result, in row order
 */
@Override
public Mono<List<Entry>> getDuration(@NonNull UUID projectId, @NonNull ProjectMetricRequest request) {
    return template.nonTransaction(connection -> getMetric(projectId, request, connection,
            GET_TRACE_DURATION, "traceDuration")
            .flatMapMany(result -> result.map((row, metadata) -> mapDuration(row)))
            // Each row yields a short stream (at most three entries); emit them in order.
            .concatMapIterable(Stream::toList)
            .collectList());
}

/**
 * Converts one result row into its p50, p90 and p99 duration entries.
 *
 * <p>The row's {@code duration} column is read as a list whose elements 0, 1 and 2 become the
 * values of the three entries; all of them share the row's {@code bucket} timestamp.
 *
 * @param row result row exposing the {@code bucket} and {@code duration} columns
 * @return the three entries, or an empty stream when {@code duration} is {@code null}
 */
private Stream<Entry> mapDuration(Row row) {
    List<?> percentiles = row.get("duration", List.class);
    if (percentiles == null) {
        return Stream.empty();
    }

    Instant bucket = row.get("bucket", Instant.class);
    Entry p50 = Entry.builder()
            .name(NAME_DURATION_P50)
            .time(bucket)
            .value(getP(percentiles, 0))
            .build();
    Entry p90 = Entry.builder()
            .name(NAME_DURATION_P90)
            .time(bucket)
            .value(getP(percentiles, 1))
            .build();
    Entry p99 = Entry.builder()
            .name(NAME_DURATION_P99)
            .time(bucket)
            .value(getP(percentiles, 2))
            .build();
    return Stream.of(p50, p90, p99);
}

/**
 * Returns the percentile stored at {@code index} of a row's duration array.
 *
 * <p>The parameter is declared as {@code List<?>} rather than a raw {@code List}; callers
 * holding a raw list can still pass it unchanged.
 *
 * @param durations percentile values of one row; elements are expected to be {@link BigDecimal}
 * @param index     zero-based position of the wanted percentile
 * @return the value at {@code index}, or {@code null} when the list has no element there
 *         (presumably the case for gap-filled rows, whose array is empty - TODO confirm)
 */
private static BigDecimal getP(List<?> durations, int index) {
    if (index >= durations.size()) {
        return null;
    }

    return (BigDecimal) durations.get(index);
}

@Override
public Mono<List<Entry>> getTraceCount(@NonNull UUID projectId, @NonNull ProjectMetricRequest request) {
return template.nonTransaction(connection -> getMetric(projectId, request, connection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@
import com.google.inject.ImplementedBy;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.ws.rs.BadRequestException;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Mono;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
Expand All @@ -30,15 +27,22 @@ public interface ProjectMetricsService {

@Slf4j
@Singleton
@RequiredArgsConstructor(onConstructor_ = @Inject)
class ProjectMetricsServiceImpl implements ProjectMetricsService {
private final @NonNull ProjectMetricsDAO projectMetricsDAO;
private final @NonNull Map<MetricType, BiFunction<UUID, ProjectMetricRequest, Mono<List<ProjectMetricsDAO.Entry>>>> metricHandler;

@Inject
public ProjectMetricsServiceImpl(@NonNull ProjectMetricsDAO projectMetricsDAO) {
metricHandler = Map.of(
MetricType.TRACE_COUNT, projectMetricsDAO::getTraceCount,
MetricType.FEEDBACK_SCORES, projectMetricsDAO::getFeedbackScores,
MetricType.TOKEN_USAGE, projectMetricsDAO::getTokenUsage,
MetricType.COST, projectMetricsDAO::getCost,
MetricType.DURATION, projectMetricsDAO::getDuration);
}

@Override
public Mono<ProjectMetricResponse<Number>> getProjectMetrics(UUID projectId, ProjectMetricRequest request) {
return getMetricHandler(request.metricType())
.orElseThrow(
() -> new BadRequestException(ERR_PROJECT_METRIC_NOT_SUPPORTED.formatted(request.metricType())))
.apply(projectId, request)
.map(dataPoints -> ProjectMetricResponse.builder()
.projectId(projectId)
Expand Down Expand Up @@ -71,15 +75,8 @@ private List<ProjectMetricResponse.Results<Number>> entriesToResults(List<Projec
.toList();
}

private Optional<BiFunction<UUID, ProjectMetricRequest, Mono<List<ProjectMetricsDAO.Entry>>>> getMetricHandler(
private BiFunction<UUID, ProjectMetricRequest, Mono<List<ProjectMetricsDAO.Entry>>> getMetricHandler(
MetricType metricType) {
Map<MetricType, BiFunction<UUID, ProjectMetricRequest, Mono<List<ProjectMetricsDAO.Entry>>>> HANDLER_BY_TYPE = Map
.of(
MetricType.TRACE_COUNT, projectMetricsDAO::getTraceCount,
MetricType.FEEDBACK_SCORES, projectMetricsDAO::getFeedbackScores,
MetricType.TOKEN_USAGE, projectMetricsDAO::getTokenUsage,
MetricType.COST, projectMetricsDAO::getCost);

return Optional.ofNullable(HANDLER_BY_TYPE.get(metricType));
return metricHandler.get(metricType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,11 @@ AND id in (
FROM
(
SELECT
id
id,
if(end_time IS NOT NULL AND start_time IS NOT NULL
AND notEquals(start_time, toDateTime64('1970-01-01 00:00:00.000', 9)),
(dateDiff('microsecond', start_time, end_time) / 1000.0),
NULL) AS duration_millis
FROM spans
WHERE project_id = :project_id
AND workspace_id = :workspace_id
Expand Down Expand Up @@ -1211,6 +1215,7 @@ private void bindSearchCriteria(Statement statement, SpanSearchCriteria spanSear
.ifPresent(filters -> {
filterQueryBuilder.bind(statement, filters, FilterStrategy.SPAN);
filterQueryBuilder.bind(statement, filters, FilterStrategy.FEEDBACK_SCORES);
filterQueryBuilder.bind(statement, filters, FilterStrategy.DURATION);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,11 @@ WHERE created_at BETWEEN toStartOfDay(yesterday()) AND toStartOfDay(today())
sum(s.total_estimated_cost) as total_estimated_cost
FROM (
SELECT
id
id,
if(end_time IS NOT NULL AND start_time IS NOT NULL
AND notEquals(start_time, toDateTime64('1970-01-01 00:00:00.000', 9)),
(dateDiff('microsecond', start_time, end_time) / 1000.0),
NULL) AS duration_millis
FROM traces
WHERE project_id = :project_id
AND workspace_id = :workspace_id
Expand Down Expand Up @@ -1052,6 +1056,7 @@ private void bindSearchCriteria(TraceSearchCriteria traceSearchCriteria, Stateme
filterQueryBuilder.bind(statement, filters, FilterStrategy.TRACE);
filterQueryBuilder.bind(statement, filters, FilterStrategy.TRACE_AGGREGATION);
filterQueryBuilder.bind(statement, filters, FilterStrategy.FEEDBACK_SCORES);
filterQueryBuilder.bind(statement, filters, FilterStrategy.DURATION);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,7 @@ public class FilterQueryBuilder {
private static final String USAGE_PROMPT_TOKENS_ANALYTICS_DB = "usage['prompt_tokens']";
private static final String USAGE_TOTAL_TOKENS_ANALYTICS_DB = "usage['total_tokens']";
private static final String VALUE_ANALYTICS_DB = "value";
private static final String DURATION_ANALYTICS_DB = """
if(end_time IS NOT NULL AND start_time IS NOT NULL
AND notEquals(start_time, toDateTime64('1970-01-01 00:00:00.000', 9)),
(dateDiff('microsecond', start_time, end_time) / 1000.0),
NULL)
""";
private static final String DURATION_ANALYTICS_DB = "duration_millis";

private static final Map<Operator, Map<FieldType, String>> ANALYTICS_DB_OPERATOR_MAP = new EnumMap<>(Map.of(
Operator.CONTAINS, new EnumMap<>(Map.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ public enum FilterStrategy {
SPAN,
EXPERIMENT_ITEM,
DATASET_ITEM,
FEEDBACK_SCORES;
FEEDBACK_SCORES,
DURATION;

public static final String DYNAMIC_FIELD = ":dynamicField%1$d";

Expand Down
Loading
Loading