Skip to content

Commit

Permalink
Change approach to persist cost in DB
Browse files Browse the repository at this point in the history
  • Loading branch information
Borys Tkachenko committed Nov 20, 2024
1 parent 433ae94 commit a4cb6eb
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 29 deletions.
4 changes: 3 additions & 1 deletion apps/opik-backend/src/main/java/com/comet/opik/api/Span.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import jakarta.validation.constraints.Pattern;
import lombok.Builder;

import java.math.BigDecimal;
import java.time.Instant;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -49,7 +50,8 @@ public record Span(
@JsonView({Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy,
@JsonView({
Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) List<FeedbackScore> feedbackScores,
@JsonView({Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Double totalEstimatedCost){
@JsonView({
Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) BigDecimal totalEstimatedCost){

public record SpanPage(
@JsonView(Span.View.Public.class) int page,
Expand Down
92 changes: 74 additions & 18 deletions apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.stringtemplate.v4.ST;
import reactor.core.publisher.Mono;

import java.math.BigDecimal;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -70,6 +71,8 @@ INSERT INTO spans(
metadata,
model,
provider,
total_estimated_cost,
total_estimated_cost_version,
tags,
usage,
created_by,
Expand All @@ -91,6 +94,8 @@ INSERT INTO spans(
:metadata<item.index>,
:model<item.index>,
:provider<item.index>,
toDecimal32(:total_estimated_cost<item.index>, 6),
:total_estimated_cost_version<item.index>,
:tags<item.index>,
mapFromArrays(:usage_keys<item.index>, :usage_values<item.index>),
:created_by<item.index>,
Expand Down Expand Up @@ -123,6 +128,8 @@ INSERT INTO spans(
metadata,
model,
provider,
total_estimated_cost,
total_estimated_cost_version,
tags,
usage,
created_at,
Expand Down Expand Up @@ -187,6 +194,14 @@ INSERT INTO spans(
LENGTH(old_span.provider) > 0, old_span.provider,
new_span.provider
) as provider,
multiIf(
old_span.total_estimated_cost > toDecimal32(0, 6), old_span.total_estimated_cost,
new_span.total_estimated_cost
) as total_estimated_cost,
multiIf(
LENGTH(old_span.total_estimated_cost_version) > 0, old_span.total_estimated_cost_version,
new_span.total_estimated_cost_version
) as total_estimated_cost_version,
multiIf(
notEmpty(old_span.tags), old_span.tags,
new_span.tags
Expand Down Expand Up @@ -220,6 +235,8 @@ INSERT INTO spans(
:metadata as metadata,
:model as model,
:provider as provider,
toDecimal32(:total_estimated_cost, 6) as total_estimated_cost,
:total_estimated_cost_version as total_estimated_cost_version,
:tags as tags,
mapFromArrays(:usage_keys, :usage_values) as usage,
now64(9) as created_at,
Expand Down Expand Up @@ -258,6 +275,8 @@ INSERT INTO spans (
metadata,
model,
provider,
total_estimated_cost,
total_estimated_cost_version,
tags,
usage,
created_at,
Expand All @@ -278,6 +297,8 @@ INSERT INTO spans (
<if(metadata)> :metadata <else> metadata <endif> as metadata,
<if(model)> :model <else> model <endif> as model,
<if(provider)> :provider <else> provider <endif> as provider,
<if(total_estimated_cost)> toDecimal32(:total_estimated_cost, 6) <else> total_estimated_cost <endif> as total_estimated_cost,
<if(total_estimated_cost_version)> :total_estimated_cost_version <else> total_estimated_cost_version <endif> as total_estimated_cost_version,
<if(tags)> :tags <else> tags <endif> as tags,
<if(usage)> CAST((:usageKeys, :usageValues), 'Map(String, Int64)') <else> usage <endif> as usage,
created_at,
Expand All @@ -304,7 +325,7 @@ INSERT INTO spans (
private static final String PARTIAL_INSERT = """
INSERT INTO spans(
id, project_id, workspace_id, trace_id, parent_span_id, name, type,
start_time, end_time, input, output, metadata, model, provider, tags, usage, created_at,
start_time, end_time, input, output, metadata, model, provider, total_estimated_cost, total_estimated_cost_version, tags, usage, created_at,
created_by, last_updated_by
)
SELECT
Expand Down Expand Up @@ -371,6 +392,16 @@ INSERT INTO spans(
LENGTH(old_span.provider) > 0, old_span.provider,
new_span.provider
) as provider,
multiIf(
new_span.total_estimated_cost > toDecimal32(0, 6), new_span.total_estimated_cost,
old_span.total_estimated_cost > toDecimal32(0, 6), old_span.total_estimated_cost,
new_span.total_estimated_cost
) as total_estimated_cost,
multiIf(
LENGTH(new_span.total_estimated_cost_version) > 0, new_span.total_estimated_cost_version,
LENGTH(old_span.total_estimated_cost_version) > 0, old_span.total_estimated_cost_version,
new_span.total_estimated_cost_version
) as total_estimated_cost_version,
multiIf(
notEmpty(new_span.tags), new_span.tags,
notEmpty(old_span.tags), old_span.tags,
Expand Down Expand Up @@ -406,6 +437,8 @@ INSERT INTO spans(
<if(metadata)> :metadata <else> '' <endif> as metadata,
<if(model)> :model <else> '' <endif> as model,
<if(provider)> :provider <else> '' <endif> as provider,
<if(total_estimated_cost)> toDecimal32(:total_estimated_cost, 6) <else> toDecimal32(0, 6) <endif> as total_estimated_cost,
<if(total_estimated_cost_version)> :total_estimated_cost_version <else> '' <endif> as total_estimated_cost_version,
<if(tags)> :tags <else> [] <endif> as tags,
<if(usage)> CAST((:usageKeys, :usageValues), 'Map(String, Int64)') <else> mapFromArrays([], []) <endif> as usage,
now64(9) as created_at,
Expand Down Expand Up @@ -452,6 +485,7 @@ LEFT JOIN (
<if(truncate)> replaceRegexpAll(metadata, '<truncate>', '"[image]"') as metadata <else> metadata <endif>,
model,
provider,
total_estimated_cost,
tags,
usage,
created_at,
Expand Down Expand Up @@ -546,6 +580,9 @@ AND id in (
LIMIT 1 BY id
""";

private static final String ESTIMATED_COST_VERSION = "1.0";
private static final BigDecimal COST_NOT_AVAILABLE = new BigDecimal("0.000000");

private final @NonNull ConnectionFactory connectionFactory;
private final @NonNull FeedbackScoreDAO feedbackScoreDAO;
private final @NonNull FilterQueryBuilder filterQueryBuilder;
Expand Down Expand Up @@ -581,6 +618,8 @@ private Publisher<? extends Result> insert(List<Span> spans, Connection connecti
int i = 0;
for (Span span : spans) {

double estimatedCost = calculateCost(span);

statement.bind("id" + i, span.id())
.bind("project_id" + i, span.projectId())
.bind("trace_id" + i, span.traceId())
Expand All @@ -593,6 +632,8 @@ private Publisher<? extends Result> insert(List<Span> spans, Connection connecti
.bind("metadata" + i, span.metadata() != null ? span.metadata().toString() : "")
.bind("model" + i, span.model() != null ? span.model() : "")
.bind("provider" + i, span.provider() != null ? span.provider() : "")
.bind("total_estimated_cost" + i, estimatedCost)
.bind("total_estimated_cost_version" + i, estimatedCost > 0 ? ESTIMATED_COST_VERSION : "")
.bind("tags" + i, span.tags() != null ? span.tags().toArray(String[]::new) : new String[]{})
.bind("created_by" + i, userName)
.bind("last_updated_by" + i, userName);
Expand Down Expand Up @@ -675,6 +716,15 @@ private Publisher<? extends Result> insert(Span span, Connection connection) {
} else {
statement.bind("provider", "");
}

double estimatedCost = calculateCost(span);
statement.bind("total_estimated_cost", estimatedCost);
if (estimatedCost > 0) {
statement.bind("total_estimated_cost_version", ESTIMATED_COST_VERSION);
} else {
statement.bind("total_estimated_cost_version", "");
}

if (span.tags() != null) {
statement.bind("tags", span.tags().toArray(String[]::new));
} else {
Expand Down Expand Up @@ -788,6 +838,12 @@ private void bindUpdateParams(SpanUpdate spanUpdate, Statement statement) {
.ifPresent(model -> statement.bind("model", model));
Optional.ofNullable(spanUpdate.provider())
.ifPresent(provider -> statement.bind("provider", provider));

if (StringUtils.isNotBlank(spanUpdate.model()) && Objects.nonNull(spanUpdate.usage())) {
statement.bind("total_estimated_cost",
ModelPrice.fromString(spanUpdate.model()).calculateCost(spanUpdate.usage()));
statement.bind("total_estimated_cost_version", ESTIMATED_COST_VERSION);
}
}

private ST newUpdateTemplate(SpanUpdate spanUpdate, String sql) {
Expand All @@ -808,6 +864,10 @@ private ST newUpdateTemplate(SpanUpdate spanUpdate, String sql) {
.ifPresent(endTime -> template.add("end_time", endTime.toString()));
Optional.ofNullable(spanUpdate.usage())
.ifPresent(usage -> template.add("usage", usage.toString()));
if (StringUtils.isNotBlank(spanUpdate.model()) && Objects.nonNull(spanUpdate.usage())) {
template.add("total_estimated_cost", "total_estimated_cost");
template.add("total_estimated_cost_version", "total_estimated_cost_version");
}
return template;
}

Expand All @@ -817,8 +877,7 @@ public Mono<Span> getById(@NonNull UUID id) {
return Mono.from(connectionFactory.create())
.flatMapMany(connection -> getById(id, connection))
.flatMap(this::mapToDto)
.flatMap(span -> enhanceWithFeedbackScores(List.of(span)).map(this::enhanceWithSpanCost)
.map(List::getFirst))
.flatMap(span -> enhanceWithFeedbackScores(List.of(span)).map(List::getFirst))
.singleOrEmpty();
}

Expand Down Expand Up @@ -879,6 +938,9 @@ private Publisher<Span> mapToDto(Result result) {
.orElse(null))
.model(row.get("model", String.class))
.provider(row.get("provider", String.class))
.totalEstimatedCost(row.get("total_estimated_cost", BigDecimal.class).equals(COST_NOT_AVAILABLE)
? null
: row.get("total_estimated_cost", BigDecimal.class))
.tags(Optional.of(Arrays.stream(row.get("tags", String[].class)).collect(Collectors.toSet()))
.filter(set -> !set.isEmpty())
.orElse(null))
Expand All @@ -903,7 +965,6 @@ private Mono<Span.SpanPage> find(int page, int size, SpanSearchCriteria spanSear
.flatMap(this::mapToDto)
.collectList()
.flatMap(this::enhanceWithFeedbackScores)
.map(this::enhanceWithSpanCost)
.map(spans -> new Span.SpanPage(page, spans.size(), total, spans));
}

Expand All @@ -919,20 +980,15 @@ private Mono<List<Span>> enhanceWithFeedbackScores(List<Span> spans) {
.doFinally(signalType -> endSegment(segment));
}

private List<Span> enhanceWithSpanCost(List<Span> spans) {
return spans.stream()
.map(span -> {
// Later we could just use span.model(), but now it's still located inside metadata
String model = StringUtils.isNotBlank(span.model())
? span.model()
: Optional.ofNullable(span.metadata())
.map(metadata -> metadata.get("model"))
.map(JsonNode::asText).orElse("");
return span.toBuilder()
.totalEstimatedCost(ModelPrice.fromString(model).calculateCost(span.usage()))
.build();
})
.toList();
private double calculateCost(Span span) {
// Later we could just use span.model(), but now it's still located inside metadata
String model = StringUtils.isNotBlank(span.model())
? span.model()
: Optional.ofNullable(span.metadata())
.map(metadata -> metadata.get("model"))
.map(JsonNode::asText).orElse("");

return ModelPrice.fromString(model).calculateCost(span.usage());
}

private Publisher<? extends Result> find(int page, int size, SpanSearchCriteria spanSearchCriteria,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.StringUtils;

import java.util.Arrays;
import java.util.Map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ public static double textGenerationCost(ModelPrice modelPrice, Map<String, Integ
}

public static double defaultCost(ModelPrice modelPrice, Map<String, Integer> usage) {
return -1;
return 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.spans
ADD COLUMN IF NOT EXISTS model String DEFAULT '',
ADD COLUMN IF NOT EXISTS provider String DEFAULT '';
ADD COLUMN IF NOT EXISTS provider String DEFAULT '',
ADD COLUMN IF NOT EXISTS total_estimated_cost Decimal32(6) DEFAULT toDecimal32(0, 6),
ADD COLUMN IF NOT EXISTS total_estimated_cost_version String DEFAULT '';

--rollback ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.spans DROP COLUMN IF EXISTS model, DROP COLUMN IF EXISTS provider;
--rollback ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.spans DROP COLUMN IF EXISTS model, DROP COLUMN IF EXISTS provider, DROP COLUMN IF EXISTS total_estimated_cost, DROP COLUMN IF EXISTS total_estimated_cost_version;
Original file line number Diff line number Diff line change
Expand Up @@ -3199,7 +3199,7 @@ void createAndGetById() {

@ParameterizedTest
@MethodSource
void createAndGetCost(Double expectedCost, String model, JsonNode metadata) {
void createAndGetCost(BigDecimal expectedCost, String model, JsonNode metadata) {
var expectedSpan = podamFactory.manufacturePojo(Span.class).toBuilder()
.model(model)
.metadata(metadata)
Expand All @@ -3217,11 +3217,11 @@ static Stream<Arguments> createAndGetCost() {
.getJsonNodeFromString(
"{\"created_from\":\"openai\",\"type\":\"openai_chat\",\"model\":\"gpt-3.5-turbo\"}");
return Stream.of(
Arguments.of(10.0, "gpt-3.5-turbo-1106", null),
Arguments.of(10.0, "gpt-3.5-turbo-1106", metadata),
Arguments.of(12.0, "", metadata),
Arguments.of(-1.0, "unknown-model", null),
Arguments.of(-1.0, "", null));
Arguments.of(new BigDecimal("10.000000"), "gpt-3.5-turbo-1106", null),
Arguments.of(new BigDecimal("10.000000"), "gpt-3.5-turbo-1106", metadata),
Arguments.of(new BigDecimal("12.000000"), "", metadata),
Arguments.of(null, "unknown-model", null),
Arguments.of(null, "", null));
}

@Test
Expand Down

0 comments on commit a4cb6eb

Please sign in to comment.