From 19bdb8bb04fd60d4cb70c97b6484499abe1d15d1 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Fri, 13 Dec 2024 12:05:20 +0100 Subject: [PATCH] [OPIK-446] Add duration metric --- .../comet/opik/domain/ProjectMetricsDAO.java | 61 ++++++++ .../opik/domain/ProjectMetricsService.java | 30 ++-- .../java/com/comet/opik/domain/SpanDAO.java | 4 +- .../java/com/comet/opik/domain/TraceDAO.java | 4 +- .../000008_add_duration_columns.sql | 11 ++ .../v1/priv/ProjectMetricsResourceTest.java | 141 ++++++++++++++---- 6 files changed, 205 insertions(+), 46 deletions(-) create mode 100644 apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000008_add_duration_columns.sql diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsDAO.java index 76a20f878e..5ccfd08c08 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsDAO.java @@ -25,6 +25,7 @@ import java.util.Optional; import java.util.UUID; import java.util.function.Function; +import java.util.stream.Stream; import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToMono; import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.endSegment; @@ -35,11 +36,16 @@ public interface ProjectMetricsDAO { String NAME_TRACES = "traces"; String NAME_COST = "cost"; + String NAME_DURATION_P50 = "duration.p50"; + String NAME_DURATION_P90 = "duration.p90"; + String NAME_DURATION_P99 = "duration.p99"; @Builder record Entry(String name, Instant time, Number value) { + } + Mono> getDuration(UUID projectId, ProjectMetricRequest request); Mono> getTraceCount(@NonNull UUID projectId, @NonNull ProjectMetricRequest request); Mono> getFeedbackScores(@NonNull UUID projectId, @NonNull ProjectMetricRequest request); Mono> getTokenUsage(@NonNull UUID projectId, @NonNull ProjectMetricRequest request); @@ -50,6 +56,7 @@ record Entry(String name, Instant time, Number value) { @Singleton @RequiredArgsConstructor(onConstructor_ = @Inject) class ProjectMetricsDAOImpl implements ProjectMetricsDAO { + private final @NonNull TransactionTemplateAsync template; private static final Map INTERVAL_TO_SQL = Map.of( @@ -57,6 +64,25 @@ class ProjectMetricsDAOImpl implements ProjectMetricsDAO { TimeInterval.DAILY, "toIntervalDay(1)", TimeInterval.HOURLY, "toIntervalHour(1)"); + private static final String GET_TRACE_DURATION = """ + SELECT AS bucket, + arrayMap(v -> + toDecimal64(if(isNaN(v), 0, v), 9), + quantiles(0.5, 0.9, 0.99)(duration) + ) AS duration + FROM traces + WHERE project_id = :project_id + AND workspace_id = :workspace_id + AND start_time >= parseDateTime64BestEffort(:start_time, 9) + AND start_time \\<= parseDateTime64BestEffort(:end_time, 9) + GROUP BY bucket + ORDER BY bucket + WITH FILL + FROM + TO parseDateTimeBestEffort(:end_time) + STEP ; + """; + private static final String GET_TRACE_COUNT = """ SELECT AS bucket, nullIf(count(DISTINCT id), 0) as count @@ -151,6 +177,41 @@ TO parseDateTimeBestEffort(:end_time) STEP ; """; + @Override + public Mono> getDuration(@NonNull UUID projectId, @NonNull ProjectMetricRequest request) { + return template.nonTransaction(connection -> getMetric(projectId, request, connection, + GET_TRACE_DURATION, "traceDuration") + .flatMapMany(result -> result.map((row, metadata) -> mapDuration(row))) + .reduce(Stream::concat) + .map(Stream::toList)); + } + + private Stream mapDuration(Row row) { + return Optional.ofNullable(row.get("duration", List.class)) + .map(durations -> Stream.of( + Entry.builder().name(NAME_DURATION_P50) + .time(row.get("bucket", Instant.class)) + .value(getP(durations, 0)) + .build(), + Entry.builder().name(NAME_DURATION_P90) + .time(row.get("bucket", Instant.class)) + .value(getP(durations, 1)) + .build(), + Entry.builder().name(NAME_DURATION_P99) + .time(row.get("bucket", Instant.class)) + .value(getP(durations, 2)) + .build())) + .orElse(Stream.empty()); + } + + private static BigDecimal getP(List durations, int index) { + if (durations.size() <= index) { + return null; + } + + return (BigDecimal) durations.get(index); + } + @Override public Mono> getTraceCount(@NonNull UUID projectId, @NonNull ProjectMetricRequest request) { return template.nonTransaction(connection -> getMetric(projectId, request, connection, diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java index 49e3c1d729..1223dbc340 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ProjectMetricsService.java @@ -9,7 +9,6 @@ import jakarta.inject.Singleton; import jakarta.ws.rs.BadRequestException; import lombok.NonNull; -import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import reactor.core.publisher.Mono; @@ -30,15 +29,23 @@ public interface ProjectMetricsService { @Slf4j @Singleton -@RequiredArgsConstructor(onConstructor_ = @Inject) class ProjectMetricsServiceImpl implements ProjectMetricsService { - private final @NonNull ProjectMetricsDAO projectMetricsDAO; + + private final @NonNull Map>>> metricHandler; + + @Inject + public ProjectMetricsServiceImpl(@NonNull ProjectMetricsDAO projectMetricsDAO) { + metricHandler = Map.of( + MetricType.TRACE_COUNT, projectMetricsDAO::getTraceCount, + MetricType.FEEDBACK_SCORES, projectMetricsDAO::getFeedbackScores, + MetricType.TOKEN_USAGE, projectMetricsDAO::getTokenUsage, + MetricType.COST, projectMetricsDAO::getCost, + MetricType.DURATION, projectMetricsDAO::getDuration); + } @Override public Mono> getProjectMetrics(UUID projectId, ProjectMetricRequest request) { return getMetricHandler(request.metricType()) - .orElseThrow( - () -> new BadRequestException(ERR_PROJECT_METRIC_NOT_SUPPORTED.formatted(request.metricType()))) .apply(projectId, request) .map(dataPoints -> ProjectMetricResponse.builder() .projectId(projectId) @@ -71,15 +78,8 @@ private List> entriesToResults(List>>> getMetricHandler( - MetricType metricType) { - Map>>> HANDLER_BY_TYPE = Map - .of( - MetricType.TRACE_COUNT, projectMetricsDAO::getTraceCount, - MetricType.FEEDBACK_SCORES, projectMetricsDAO::getFeedbackScores, - MetricType.TOKEN_USAGE, projectMetricsDAO::getTokenUsage, - MetricType.COST, projectMetricsDAO::getCost); - - return Optional.ofNullable(HANDLER_BY_TYPE.get(metricType)); + private BiFunction>> getMetricHandler(MetricType metricType) { + return Optional.ofNullable(metricHandler.get(metricType)) + .orElseThrow(() -> new BadRequestException(ERR_PROJECT_METRIC_NOT_SUPPORTED.formatted(metricType))); } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java index f5dd643f89..3383f98877 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java @@ -589,7 +589,7 @@ AND id in ( SELECT project_id as project_id, count(DISTINCT span_id) as span_count, - arrayMap(v -> if(isNaN(v), 0, toDecimal64(v / 1000.0, 9)), quantiles(0.5, 0.9, 0.99)(duration)) AS duration, + arrayMap(v -> toDecimal64(if(isNaN(v), 0, v), 9), quantiles(0.5, 0.9, 0.99)(duration)) AS duration, sum(input_count) as input, sum(output_count) as output, sum(metadata_count) as metadata, @@ -616,7 +616,7 @@ AND id in ( workspace_id, project_id, id, - if(end_time IS NOT NULL, date_diff('microsecond', start_time, end_time), null) as duration, + duration, if(length(input) > 0, 1, 0) as input_count, if(length(output) > 0, 1, 0) as output_count, if(length(metadata) > 0, 1, 0) as metadata_count, diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java index 0167fa8134..fdf446563b 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java @@ -570,7 +570,7 @@ LEFT JOIN ( SELECT project_id as project_id, count(DISTINCT trace_id) as trace_count, - arrayMap(v -> if(isNaN(v), 0, toDecimal64(v / 1000.0, 9)), quantiles(0.5, 0.9, 0.99)(duration)) AS duration, + arrayMap(v -> toDecimal64(if(isNaN(v), 0, v), 9), quantiles(0.5, 0.9, 0.99)(duration)) AS duration, sum(input_count) as input, sum(output_count) as output, sum(metadata_count) as metadata, @@ -597,7 +597,7 @@ LEFT JOIN ( workspace_id, project_id, id, - if(end_time IS NOT NULL, date_diff('microsecond', start_time, end_time), null) as duration, + duration, if(length(input) > 0, 1, 0) as input_count, if(length(output) > 0, 1, 0) as output_count, if(length(metadata) > 0, 1, 0) as metadata_count, diff --git a/apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000008_add_duration_columns.sql b/apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000008_add_duration_columns.sql new file mode 100644 index 0000000000..a41046ab3f --- /dev/null +++ b/apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000008_add_duration_columns.sql @@ -0,0 +1,11 @@ +--liquibase formatted sql +--changeset thiagohora:add_duration_columns + +ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.spans + ADD COLUMN IF NOT EXISTS duration Nullable(Float64) MATERIALIZED if(end_time IS NOT NULL, (dateDiff('microsecond', start_time, end_time) / 1000.0), NULL); + +ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.traces + ADD COLUMN IF NOT EXISTS duration Nullable(Float64) MATERIALIZED if(end_time IS NOT NULL, (dateDiff('microsecond', start_time, end_time) / 1000.0), NULL); + +--rollback ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.spans DROP COLUMN IF EXISTS duration; +--rollback ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.traces DROP COLUMN IF EXISTS duration; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java index f3836bda73..f20aaf6d6f 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectMetricsResourceTest.java @@ -32,9 +32,9 @@ import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; import org.apache.commons.lang3.RandomStringUtils; +import org.apache.commons.lang3.reflect.TypeUtils; import org.apache.hc.core5.http.HttpStatus; import org.jdbi.v3.core.Jdbi; -import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -55,7 +55,6 @@ import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; import uk.co.jemos.podam.api.PodamFactory; -import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.math.BigDecimal; import java.math.RoundingMode; @@ -296,6 +295,7 @@ void getProjectMetrics__whenSessionTokenIsPresent__thenReturnProperResponse( @DisplayName("Number of traces") @TestInstance(TestInstance.Lifecycle.PER_CLASS) class NumberOfTracesTest { + @ParameterizedTest @EnumSource(TimeInterval.class) void happyPath(TimeInterval interval) { @@ -362,11 +362,8 @@ public static Stream invalidParameters() { arguments(named("start equal to end", validReq.toBuilder() .intervalStart(now) .intervalEnd(now) - .build()), ProjectMetricsService.ERR_START_BEFORE_END), - arguments(named("not supported metric", validReq.toBuilder() - .metricType(MetricType.DURATION) - .build()), ProjectMetricsService.ERR_PROJECT_METRIC_NOT_SUPPORTED.formatted( - MetricType.DURATION))); + .build()), ProjectMetricsService.ERR_START_BEFORE_END) + ); } @ParameterizedTest @@ -409,6 +406,7 @@ private void createTraces(String projectName, Instant marker, int count) { @DisplayName("Feedback scores") @TestInstance(TestInstance.Lifecycle.PER_CLASS) class FeedbackScoresTest { + @ParameterizedTest @EnumSource(TimeInterval.class) void happyPath(TimeInterval interval) { @@ -499,6 +497,7 @@ private static BigDecimal calcAverage(List scores) { @DisplayName("Token usage") @TestInstance(TestInstance.Lifecycle.PER_CLASS) class TokenUsageTest { + @ParameterizedTest @EnumSource(TimeInterval.class) void happyPath(TimeInterval interval) { @@ -609,6 +608,7 @@ private void getAndAssertEmpty(UUID projectId, TimeInterval interval, Instant ma @DisplayName("Cost") @TestInstance(TestInstance.Lifecycle.PER_CLASS) class CostTest { + @ParameterizedTest @EnumSource(TimeInterval.class) void happyPath(TimeInterval interval) { @@ -686,6 +686,111 @@ private BigDecimal createSpans( } } + @Nested + @DisplayName("Duration") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class DurationTest { + + @ParameterizedTest + @EnumSource(TimeInterval.class) + void happyPath(TimeInterval interval) { + // setup + mockTargetWorkspace(); + + Instant marker = getIntervalStart(interval); + String projectName = RandomStringUtils.randomAlphabetic(10); + var projectId = projectResourceClient.createProject(projectName, API_KEY, WORKSPACE_NAME); + + List durationsMinus3 = createTraces(projectName, subtract(marker, TIME_BUCKET_3, interval)); + List durationsMinus1 = createTraces(projectName, subtract(marker, TIME_BUCKET_1, interval)); + List durationsCurrent = createTraces(projectName, marker); + + var durationMinus3 = Map.of( + ProjectMetricsDAO.NAME_DURATION_P50, durationsMinus3.get(0), + ProjectMetricsDAO.NAME_DURATION_P90, durationsMinus3.get(1), + ProjectMetricsDAO.NAME_DURATION_P99, durationsMinus3.getLast()); + var durationMinus1 = Map.of( + ProjectMetricsDAO.NAME_DURATION_P50, durationsMinus1.get(0), + ProjectMetricsDAO.NAME_DURATION_P90, durationsMinus1.get(1), + ProjectMetricsDAO.NAME_DURATION_P99, durationsMinus1.getLast()); + var durationCurrent = Map.of( + ProjectMetricsDAO.NAME_DURATION_P50, durationsCurrent.get(0), + ProjectMetricsDAO.NAME_DURATION_P90, durationsCurrent.get(1), + ProjectMetricsDAO.NAME_DURATION_P99, durationsCurrent.getLast()); + + getMetricsAndAssert( + projectId, + ProjectMetricRequest.builder() + .metricType(MetricType.DURATION) + .interval(interval) + .intervalStart(subtract(marker, TIME_BUCKET_4, interval)) + .intervalEnd(Instant.now()) + .build(), + marker, + List.of(ProjectMetricsDAO.NAME_DURATION_P50, ProjectMetricsDAO.NAME_DURATION_P90, ProjectMetricsDAO.NAME_DURATION_P99), + BigDecimal.class, + durationMinus3, + durationMinus1, + durationCurrent); + } + + @ParameterizedTest + @EnumSource(TimeInterval.class) + void emptyData(TimeInterval interval) { + // setup + mockTargetWorkspace(); + + Instant marker = getIntervalStart(interval); + String projectName = RandomStringUtils.randomAlphabetic(10); + var projectId = projectResourceClient.createProject(projectName, API_KEY, WORKSPACE_NAME); + + Map empty = new HashMap<>() { + { + put(ProjectMetricsDAO.NAME_DURATION_P50, null); + put(ProjectMetricsDAO.NAME_DURATION_P90, null); + put(ProjectMetricsDAO.NAME_DURATION_P99, null); + } + }; + + getMetricsAndAssert( + projectId, + ProjectMetricRequest.builder() + .metricType(MetricType.DURATION) + .interval(interval) + .intervalStart(subtract(marker, TIME_BUCKET_4, interval)) + .intervalEnd(Instant.now()) + .build(), + marker, + List.of(ProjectMetricsDAO.NAME_DURATION_P50, ProjectMetricsDAO.NAME_DURATION_P90, ProjectMetricsDAO.NAME_DURATION_P99), + BigDecimal.class, + empty, + empty, + empty + ); + } + + private List createTraces(String projectName, Instant marker) { + List traces = IntStream.range(0, 5) + .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .startTime(marker) + .endTime(marker.plusMillis(RANDOM.nextInt(1000))) + .build()) + .toList(); + + traceResourceClient.batchCreateTraces(traces, API_KEY, WORKSPACE_NAME); + + return StatsUtils.calculateQuantiles( + traces.stream() + .filter(entity -> entity.endTime() != null) + .map(entity -> entity.startTime().until(entity.endTime(), ChronoUnit.MICROS)) + .map(duration -> duration / 1_000.0) + .toList(), + List.of(0.50, 0.90, 0.99)); + } + + } + private ProjectMetricResponse getProjectMetrics( UUID projectId, ProjectMetricRequest request, Class aClass) { try (var response = client.target(URL_TEMPLATE.formatted(baseURI, projectId)) @@ -697,7 +802,8 @@ private ProjectMetricResponse getProjectMetrics( assertThat(response.getStatus()).isEqualTo(HttpStatus.SC_OK); assertThat(response.hasEntity()).isTrue(); - return response.readEntity(new GenericType<>(createParameterizedProjectMetricResponse(aClass))); + Type parameterize = TypeUtils.parameterize(ProjectMetricResponse.class, aClass); + return response.readEntity(new GenericType<>(parameterize)); } } @@ -722,25 +828,6 @@ private void getMetricsAndAssert( .isEqualTo(expected); } - private static Type createParameterizedProjectMetricResponse(Class genericArgument) { - return new ParameterizedType() { - @NotNull @Override - public Type[] getActualTypeArguments() { - return new Type[]{genericArgument}; - } - - @NotNull @Override - public Type getRawType() { - return ProjectMetricResponse.class; - } - - @Override - public Type getOwnerType() { - return null; - } - }; - } - private static List> createExpected( Instant marker, TimeInterval interval, List names, Map dataMinus3, Map dataMinus1, Map dataNow) {