OPIK-75: Add batch spans creation endpoint (#205)
* [NA] Remove creation time from sort key

* OPIK-75: Add bulk spans creation endpoint

* Add tests

* Add code review suggestions

* PR review feedback

* Remove remaining dead code

* Fix tests
thiagohora authored Sep 13, 2024
1 parent 102439a commit 7b6527c
Showing 11 changed files with 496 additions and 156 deletions.
12 changes: 12 additions & 0 deletions apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java
@@ -0,0 +1,12 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonView;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;

import java.util.List;

public record SpanBatch(@NotNull @Size(min = 1, max = 1000) @JsonView( {
Span.View.Write.class}) @Valid List<Span> spans){
}
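
Aside (not part of this commit): a minimal sketch of how a caller might assemble the new batch payload. Span.builder() and the accessor names below are assumed from toBuilder() and the accessors used elsewhere in this diff; treat the field choices as illustrative.

import com.comet.opik.api.Span;
import com.comet.opik.api.SpanBatch;

import java.time.Instant;
import java.util.List;
import java.util.UUID;

class SpanBatchPayloadSketch {

    static SpanBatch twoSpanBatch(UUID traceId) {
        // Ids are left null here: the service generates them. Any ids that are supplied
        // must be unique within the batch and pass IdGenerator.validateVersion.
        Span first = Span.builder()
                .projectName("my-project") // normalized by WorkspaceUtils.getProjectName
                .traceId(traceId)
                .name("llm-call")
                .startTime(Instant.now())
                .build();

        Span second = Span.builder()
                .projectName("my-project")
                .traceId(traceId)
                .name("post-processing")
                .startTime(Instant.now())
                .build();

        // @Size(min = 1, max = 1000) on SpanBatch caps a single request at 1000 spans.
        return new SpanBatch(List.of(first, second));
    }
}
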
@@ -5,6 +5,7 @@
import com.comet.opik.api.FeedbackScore;
import com.comet.opik.api.FeedbackScoreBatch;
import com.comet.opik.api.Span;
import com.comet.opik.api.SpanBatch;
import com.comet.opik.api.SpanSearchCriteria;
import com.comet.opik.api.SpanUpdate;
import com.comet.opik.api.filter.FiltersFactory;
@@ -27,6 +28,7 @@
import jakarta.validation.Valid;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull;
import jakarta.ws.rs.ClientErrorException;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.DefaultValue;
@@ -47,6 +49,7 @@
import lombok.extern.slf4j.Slf4j;

import java.util.UUID;
import java.util.stream.Collectors;

import static com.comet.opik.utils.AsyncUtils.setRequestContext;
import static com.comet.opik.utils.ValidationUtils.validateProjectNameAndProjectId;
@@ -132,6 +135,30 @@ public Response create(
return Response.created(uri).build();
}

@POST
@Path("/batch")
@Operation(operationId = "createSpans", summary = "Create spans", description = "Create spans", responses = {
@ApiResponse(responseCode = "204", description = "No Content")})
public Response createSpans(
@RequestBody(content = @Content(schema = @Schema(implementation = SpanBatch.class))) @JsonView(Span.View.Write.class) @NotNull @Valid SpanBatch spans) {

spans.spans()
.stream()
.filter(span -> span.id() != null) // Filter out spans with null IDs
.collect(Collectors.groupingBy(Span::id))
.forEach((id, spanGroup) -> {
if (spanGroup.size() > 1) {
throw new ClientErrorException("Duplicate span id '%s'".formatted(id), 422);
}
});

spanService.create(spans)
.contextWrite(ctx -> setRequestContext(ctx, requestContext))
.block();

return Response.noContent().build();
}
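
For reference, the duplicate-id guard above in isolation, as a hedged standalone sketch: grouping by id with Collectors.counting() flags any id that occurs more than once. The real endpoint raises ClientErrorException with status 422; a plain exception stands in for it here.

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;

class DuplicateSpanIdCheckSketch {

    static void rejectDuplicateIds(List<UUID> ids) {
        // Null ids are skipped: they are assigned server-side and cannot collide.
        Map<UUID, Long> occurrences = ids.stream()
                .filter(Objects::nonNull)
                .collect(Collectors.groupingBy(id -> id, Collectors.counting()));

        occurrences.forEach((id, count) -> {
            if (count > 1) {
                // the endpoint maps this case to HTTP 422 Unprocessable Entity
                throw new IllegalArgumentException("Duplicate span id '%s'".formatted(id));
            }
        });
    }
}
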

@PATCH
@Path("{id}")
@Operation(operationId = "updateSpan", summary = "Update span by id", description = "Update span by id", responses = {
120 changes: 119 additions & 1 deletion apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java
@@ -6,6 +6,8 @@
import com.comet.opik.domain.filter.FilterQueryBuilder;
import com.comet.opik.domain.filter.FilterStrategy;
import com.comet.opik.utils.JsonUtils;
import com.comet.opik.utils.TemplateUtils;
import com.google.common.base.Preconditions;
import com.newrelic.api.agent.Segment;
import com.newrelic.api.agent.Trace;
import io.r2dbc.spi.Connection;
@@ -36,15 +38,60 @@
import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToFlux;
import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToMono;
import static com.comet.opik.domain.FeedbackScoreDAO.EntityType;
import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.endSegment;
import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.startSegment;
import static com.comet.opik.utils.AsyncUtils.makeFluxContextAware;
import static com.comet.opik.utils.AsyncUtils.makeMonoContextAware;
import static com.comet.opik.utils.TemplateUtils.getQueryItemPlaceHolder;

@Singleton
@RequiredArgsConstructor(onConstructor_ = @Inject)
@Slf4j
class SpanDAO {

private static final String BULK_INSERT = """
INSERT INTO spans(
id,
project_id,
workspace_id,
trace_id,
parent_span_id,
name,
type,
start_time,
end_time,
input,
output,
metadata,
tags,
usage,
created_by,
last_updated_by
) VALUES
<items:{item |
(
:id<item.index>,
:project_id<item.index>,
:workspace_id,
:trace_id<item.index>,
:parent_span_id<item.index>,
:name<item.index>,
:type<item.index>,
parseDateTime64BestEffort(:start_time<item.index>, 9),
if(:end_time<item.index> IS NULL, NULL, parseDateTime64BestEffort(:end_time<item.index>, 9)),
:input<item.index>,
:output<item.index>,
:metadata<item.index>,
:tags<item.index>,
mapFromArrays(:usage_keys<item.index>, :usage_values<item.index>),
:created_by<item.index>,
:last_updated_by<item.index>
)
<if(item.hasNext)>,<endif>
}>
;
""";

/**
* This query handles the insertion of a new span into the database in two cases:
* 1. When the span does not exist in the database.
@@ -444,6 +491,78 @@ public Mono<Void> insert(@NonNull Span span) {
.then();
}

@Trace(dispatcher = true)
public Mono<Long> batchInsert(@NonNull List<Span> spans) {

Preconditions.checkArgument(!spans.isEmpty(), "Spans list must not be empty");

return Mono.from(connectionFactory.create())
.flatMapMany(connection -> insert(spans, connection))
.flatMap(Result::getRowsUpdated)
.reduce(0L, Long::sum);
}

private Publisher<? extends Result> insert(List<Span> spans, Connection connection) {

return makeMonoContextAware((userName, workspaceName, workspaceId) -> {
List<TemplateUtils.QueryItem> queryItems = getQueryItemPlaceHolder(spans.size());

var template = new ST(BULK_INSERT)
.add("items", queryItems);

Statement statement = connection.createStatement(template.render());

int i = 0;
for (Span span : spans) {

statement.bind("id" + i, span.id())
.bind("project_id" + i, span.projectId())
.bind("trace_id" + i, span.traceId())
.bind("name" + i, span.name())
.bind("type" + i, span.type().toString())
.bind("start_time" + i, span.startTime().toString())
.bind("parent_span_id" + i, span.parentSpanId() != null ? span.parentSpanId() : "")
.bind("input" + i, span.input() != null ? span.input().toString() : "")
.bind("output" + i, span.output() != null ? span.output().toString() : "")
.bind("metadata" + i, span.metadata() != null ? span.metadata().toString() : "")
.bind("tags" + i, span.tags() != null ? span.tags().toArray(String[]::new) : new String[]{})
.bind("created_by" + i, userName)
.bind("last_updated_by" + i, userName);

if (span.endTime() != null) {
statement.bind("end_time" + i, span.endTime().toString());
} else {
statement.bindNull("end_time" + i, String.class);
}

if (span.usage() != null) {
Stream.Builder<String> keys = Stream.builder();
Stream.Builder<Integer> values = Stream.builder();

span.usage().forEach((key, value) -> {
keys.add(key);
values.add(value);
});

statement.bind("usage_keys" + i, keys.build().toArray(String[]::new));
statement.bind("usage_values" + i, values.build().toArray(Integer[]::new));
} else {
statement.bind("usage_keys" + i, new String[]{});
statement.bind("usage_values" + i, new Integer[]{});
}

i++;
}

statement.bind("workspace_id", workspaceId);

Segment segment = startSegment("spans", "Clickhouse", "batch_insert");

return Mono.from(statement.execute())
.doFinally(signalType -> endSegment(segment));
});
}
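
A note on the usage binding above, as an illustrative sketch: ClickHouse's mapFromArrays rebuilds the usage map from two parallel arrays, so the map is flattened entry by entry to keep keys and values aligned, and a null usage becomes a pair of empty arrays. The helper below is a sketch, not committed code.

import java.util.Map;

class UsageBindingSketch {

    record UsageArrays(String[] keys, Integer[] values) {
    }

    // Flattens a usage map into the two arrays bound as :usage_keys<i> / :usage_values<i>.
    static UsageArrays flatten(Map<String, Integer> usage) {
        if (usage == null) {
            return new UsageArrays(new String[]{}, new Integer[]{});
        }
        String[] keys = new String[usage.size()];
        Integer[] values = new Integer[usage.size()];
        int i = 0;
        for (Map.Entry<String, Integer> entry : usage.entrySet()) {
            keys[i] = entry.getKey();
            values[i] = entry.getValue();
            i++;
        }
        return new UsageArrays(keys, values);
    }
}
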

private Publisher<? extends Result> insert(Span span, Connection connection) {
var template = newInsertTemplate(span);
var statement = connection.createStatement(template.render())
@@ -788,5 +907,4 @@ public Mono<List<WorkspaceAndResourceId>> getSpanWorkspace(@NonNull Set<UUID> sp
row.get("id", UUID.class))))
.collectList();
}

}
@@ -3,6 +3,7 @@
import com.clickhouse.client.ClickHouseException;
import com.comet.opik.api.Project;
import com.comet.opik.api.Span;
import com.comet.opik.api.SpanBatch;
import com.comet.opik.api.SpanSearchCriteria;
import com.comet.opik.api.SpanUpdate;
import com.comet.opik.api.error.EntityAlreadyExistsException;
@@ -11,21 +12,26 @@
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.redis.LockService;
import com.comet.opik.utils.WorkspaceUtils;
import com.google.common.base.Preconditions;
import com.newrelic.api.agent.Trace;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.ws.rs.NotFoundException;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.comet.opik.utils.AsyncUtils.makeMonoContextAware;

@@ -34,10 +40,10 @@
@Slf4j
public class SpanService {

public static final String PROJECT_NAME_AND_WORKSPACE_MISMATCH = "Project name and workspace name do not match the existing span";
public static final String PARENT_SPAN_IS_MISMATCH = "parent_span_id does not match the existing span";
public static final String TRACE_ID_MISMATCH = "trace_id does not match the existing span";
public static final String SPAN_KEY = "Span";
public static final String PROJECT_NAME_MISMATCH = "Project name and workspace name do not match the existing span";

private final @NonNull SpanDAO spanDAO;
private final @NonNull ProjectService projectService;
@@ -116,7 +122,7 @@ private Mono<UUID> insertSpan(Span span, Project project, UUID id, Span existing
}

if (!project.id().equals(existingSpan.projectId())) {
return failWithConflict(PROJECT_NAME_AND_WORKSPACE_MISMATCH);
return failWithConflict(PROJECT_NAME_MISMATCH);
}

if (!Objects.equals(span.parentSpanId(), existingSpan.parentSpanId())) {
@@ -191,7 +197,7 @@ private <T> Mono<T> handleSpanDBError(Throwable ex) {
&& (ex.getMessage().contains("_CAST(project_id, FixedString(36))")
|| ex.getMessage()
.contains(", CAST(leftPad(workspace_id, 40, '*'), 'FixedString(19)') ::"))) {
return failWithConflict(PROJECT_NAME_AND_WORKSPACE_MISMATCH);
return failWithConflict(PROJECT_NAME_MISMATCH);
}

if (ex instanceof ClickHouseException
@@ -214,7 +220,7 @@ private <T> Mono<T> handleSpanDBError(Throwable ex) {

private Mono<Long> updateOrFail(SpanUpdate spanUpdate, UUID id, Span existingSpan, Project project) {
if (!project.id().equals(existingSpan.projectId())) {
return failWithConflict(PROJECT_NAME_AND_WORKSPACE_MISMATCH);
return failWithConflict(PROJECT_NAME_MISMATCH);
}

if (!Objects.equals(existingSpan.parentSpanId(), spanUpdate.parentSpanId())) {
@@ -244,4 +250,47 @@ public Mono<Boolean> validateSpanWorkspace(@NonNull String workspaceId, @NonNull
return spanDAO.getSpanWorkspace(spanIds)
.map(spanWorkspace -> spanWorkspace.stream().allMatch(span -> workspaceId.equals(span.workspaceId())));
}

@Trace(dispatcher = true)
public Mono<Long> create(@NonNull SpanBatch batch) {

Preconditions.checkArgument(!batch.spans().isEmpty(), "Batch spans must not be empty");

List<String> projectNames = batch.spans()
.stream()
.map(Span::projectName)
.distinct()
.toList();

Mono<List<Span>> resolveProjects = Flux.fromIterable(projectNames)
.flatMap(this::resolveProject)
.collectList()
.map(projects -> bindSpanToProjectAndId(batch, projects))
.subscribeOn(Schedulers.boundedElastic());

return resolveProjects
.flatMap(spanDAO::batchInsert);
}

private List<Span> bindSpanToProjectAndId(SpanBatch batch, List<Project> projects) {
Map<String, Project> projectPerName = projects.stream()
.collect(Collectors.toMap(Project::name, Function.identity()));

return batch.spans()
.stream()
.map(span -> {
String projectName = WorkspaceUtils.getProjectName(span.projectName());
Project project = projectPerName.get(projectName);

UUID id = span.id() == null ? idGenerator.generateId() : span.id();
IdGenerator.validateVersion(id, SPAN_KEY);

return span.toBuilder().id(id).projectId(project.id()).build();
})
.toList();
}

private Mono<Project> resolveProject(String projectName) {
return getOrCreateProject(WorkspaceUtils.getProjectName(projectName));
}
}
@@ -3,10 +3,12 @@
import com.newrelic.api.agent.DatastoreParameters;
import com.newrelic.api.agent.NewRelic;
import com.newrelic.api.agent.Segment;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import reactor.core.scheduler.Schedulers;

@Slf4j
@UtilityClass
public class InstrumentAsyncUtils {

public static Segment startSegment(String segmentName, String product, String operationName) {