Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPIK-415] Compute traces cost based on token usage #703

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jakarta.validation.constraints.Pattern;
import lombok.Builder;

import java.math.BigDecimal;
import java.time.Instant;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -42,7 +43,9 @@ public record Trace(
@JsonView({Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy,
@JsonView({Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy,
@JsonView({
Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) List<FeedbackScore> feedbackScores){
Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) List<FeedbackScore> feedbackScores,
@JsonView({
Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) BigDecimal totalEstimatedCost){

public record TracePage(
@JsonView(Trace.View.Public.class) int page,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,6 @@ AND id in (
""";

private static final String ESTIMATED_COST_VERSION = "1.0";
private static final BigDecimal ZERO_COST = new BigDecimal("0.00000000");

private final @NonNull ConnectionFactory connectionFactory;
private final @NonNull FeedbackScoreDAO feedbackScoreDAO;
Expand Down Expand Up @@ -634,7 +633,7 @@ private Publisher<? extends Result> insert(List<Span> spans, Connection connecti
.bind("provider" + i, span.provider() != null ? span.provider() : "")
.bind("total_estimated_cost" + i, estimatedCost.toString())
.bind("total_estimated_cost_version" + i,
estimatedCost.compareTo(ZERO_COST) > 0 ? ESTIMATED_COST_VERSION : "")
estimatedCost.compareTo(BigDecimal.ZERO) > 0 ? ESTIMATED_COST_VERSION : "")
.bind("tags" + i, span.tags() != null ? span.tags().toArray(String[]::new) : new String[]{})
.bind("created_by" + i, userName)
.bind("last_updated_by" + i, userName);
Expand Down Expand Up @@ -720,7 +719,7 @@ private Publisher<? extends Result> insert(Span span, Connection connection) {

BigDecimal estimatedCost = calculateCost(span);
statement.bind("total_estimated_cost", estimatedCost.toString());
if (estimatedCost.compareTo(ZERO_COST) > 0) {
if (estimatedCost.compareTo(BigDecimal.ZERO) > 0) {
statement.bind("total_estimated_cost_version", ESTIMATED_COST_VERSION);
} else {
statement.bind("total_estimated_cost_version", "");
Expand Down Expand Up @@ -939,7 +938,7 @@ private Publisher<Span> mapToDto(Result result) {
.orElse(null))
.model(row.get("model", String.class))
.provider(row.get("provider", String.class))
.totalEstimatedCost(row.get("total_estimated_cost", BigDecimal.class).equals(ZERO_COST)
.totalEstimatedCost(row.get("total_estimated_cost", BigDecimal.class).compareTo(BigDecimal.ZERO) == 0
? null
: row.get("total_estimated_cost", BigDecimal.class))
.tags(Optional.of(Arrays.stream(row.get("tags", String[].class)).collect(Collectors.toSet()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import reactor.core.publisher.Mono;
import reactor.core.publisher.SignalType;

import java.math.BigDecimal;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -250,7 +251,8 @@ INSERT INTO traces (
private static final String SELECT_BY_ID = """
SELECT
t.*,
sumMap(s.usage) as usage
sumMap(s.usage) as usage,
sum(s.total_estimated_cost) as total_estimated_cost
FROM (
SELECT
*
Expand All @@ -263,7 +265,8 @@ INSERT INTO traces (
LEFT JOIN (
SELECT
trace_id,
usage
usage,
total_estimated_cost
FROM spans
WHERE workspace_id = :workspace_id
AND trace_id = :id
Expand All @@ -279,7 +282,8 @@ LEFT JOIN (
private static final String SELECT_BY_PROJECT_ID = """
SELECT
t.*,
sumMap(s.usage) as usage
sumMap(s.usage) as usage,
sum(s.total_estimated_cost) as total_estimated_cost
FROM (
SELECT
id,
Expand Down Expand Up @@ -324,7 +328,8 @@ AND id in (
LEFT JOIN (
SELECT
trace_id,
usage
usage,
total_estimated_cost
FROM spans
WHERE workspace_id = :workspace_id
AND project_id = :project_id
Expand Down Expand Up @@ -747,6 +752,9 @@ private Publisher<Trace> mapToDto(Result result) {
.filter(it -> !it.isEmpty())
.orElse(null))
.usage(row.get("usage", Map.class))
.totalEstimatedCost(row.get("total_estimated_cost", BigDecimal.class).compareTo(BigDecimal.ZERO) == 0
? null
: row.get("total_estimated_cost", BigDecimal.class))
.createdAt(row.get("created_at", Instant.class))
.lastUpdatedAt(row.get("last_updated_at", Instant.class))
.createdBy(row.get("created_by", String.class))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ public static BigDecimal textGenerationCost(ModelPrice modelPrice, Map<String, I
}

public static BigDecimal defaultCost(ModelPrice modelPrice, Map<String, Integer> usage) {
return new BigDecimal("0");
return BigDecimal.ZERO;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.comet.opik.api.resources.utils.resources.TraceResourceClient;
import com.comet.opik.domain.FeedbackScoreMapper;
import com.comet.opik.domain.SpanType;
import com.comet.opik.domain.cost.ModelPrice;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.podam.PodamFactoryUtils;
import com.comet.opik.utils.JsonUtils;
Expand Down Expand Up @@ -108,7 +109,7 @@ class TracesResourceTest {
public static final String URL_TEMPLATE = "%s/v1/private/traces";
private static final String URL_TEMPLATE_SPANS = "%s/v1/private/spans";
private static final String[] IGNORED_FIELDS_TRACES = {"projectId", "projectName", "createdAt",
"lastUpdatedAt", "feedbackScores", "createdBy", "lastUpdatedBy"};
"lastUpdatedAt", "feedbackScores", "createdBy", "lastUpdatedBy", "totalEstimatedCost"};
private static final String[] IGNORED_FIELDS_SPANS = SpansResourceTest.IGNORED_FIELDS;
private static final String[] IGNORED_FIELDS_SCORES = {"createdAt", "lastUpdatedAt", "createdBy", "lastUpdatedBy"};

Expand Down Expand Up @@ -3224,6 +3225,54 @@ void getTraceWithUsage() {
getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE);
}

@ParameterizedTest
@MethodSource
void getTraceWithCost(String model) {
var projectName = RandomStringUtils.randomAlphanumeric(10);
var trace = factory.manufacturePojo(Trace.class)
.toBuilder()
.id(null)
.projectName(projectName)
.feedbackScores(null)
.build();
var id = create(trace, API_KEY, TEST_WORKSPACE);

var spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream()
.map(spanInStream -> spanInStream.toBuilder()
.projectName(projectName)
.traceId(id)
.usage(Map.of("completion_tokens", Math.abs(factory.manufacturePojo(Integer.class)),
"prompt_tokens", Math.abs(factory.manufacturePojo(Integer.class))))
.model(model)
.build())
.collect(Collectors.toList());

var usage = spans.stream()
.flatMap(span -> span.usage().entrySet().stream())
.map(entry -> new AbstractMap.SimpleEntry<>(entry.getKey(), Long.valueOf(entry.getValue())))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, Long::sum));

BigDecimal traceExpectedCost = spans.stream()
.map(span -> ModelPrice.fromString(span.model()).calculateCost(span.usage()))
.reduce(BigDecimal.ZERO, BigDecimal::add);

batchCreateSpansAndAssert(spans, API_KEY, TEST_WORKSPACE);

var projectId = getProjectId(projectName, TEST_WORKSPACE, API_KEY);
trace = trace.toBuilder().id(id).usage(usage).build();
Trace createdTrace = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE);
assertThat(traceExpectedCost.compareTo(BigDecimal.ZERO) == 0 ?
createdTrace.totalEstimatedCost() == null :
traceExpectedCost.compareTo(createdTrace.totalEstimatedCost()) == 0)
.isEqualTo(true);
Comment on lines +3264 to +3267
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assertThat(traceExpectedCost.compareTo(BigDecimal.ZERO) == 0 ?
createdTrace.totalEstimatedCost() == null :
traceExpectedCost.compareTo(createdTrace.totalEstimatedCost()) == 0)
.isEqualTo(true);
var actual = createdTrace.totalEstimatedCost();
traceExpectedCost = traceExpectedCost.compareTo(BigDecimal.ZERO) == 0 ? null : traceExpectedCost;
assertThat(actual)
.usingRecursiveComparison(RecursiveComparisonConfiguration.builder()
.withComparatorForType(BigDecimal::compareTo, BigDecimal.class)
.build())
.isEqualTo(traceExpectedCost);

}

static Stream<Arguments> getTraceWithCost() {
return Stream.of(
Arguments.of("gpt-3.5-turbo-1106"),
Arguments.of("unknown-model"));
}

@Test
void getTraceWithoutUsage() {
var apiKey = UUID.randomUUID().toString();
Expand Down