Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OPIK-75 Batch traces creation endpoint #229

Merged
merged 17 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.comet.opik.api;

import com.fasterxml.jackson.annotation.JsonView;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Size;

import java.util.List;

/**
 * Request payload wrapping a batch of traces for bulk creation.
 * A batch must contain between 1 and 1000 traces; each trace is
 * validated individually and serialized with the write view.
 */
public record TraceBatch(
        @NotNull @Size(min = 1, max = 1000) @Valid @JsonView({Trace.View.Write.class}) List<Trace> traces) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.comet.opik.api.FeedbackScore;
import com.comet.opik.api.FeedbackScoreBatch;
import com.comet.opik.api.Trace;
import com.comet.opik.api.TraceBatch;
import com.comet.opik.api.TraceSearchCriteria;
import com.comet.opik.api.TraceUpdate;
import com.comet.opik.api.filter.FiltersFactory;
Expand All @@ -25,6 +26,7 @@
import jakarta.validation.Valid;
import jakarta.validation.constraints.Min;
import jakarta.validation.constraints.NotNull;
import jakarta.ws.rs.ClientErrorException;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.DefaultValue;
Expand All @@ -45,6 +47,7 @@
import lombok.extern.slf4j.Slf4j;

import java.util.UUID;
import java.util.stream.Collectors;

import static com.comet.opik.utils.AsyncUtils.setRequestContext;
import static com.comet.opik.utils.ValidationUtils.validateProjectNameAndProjectId;
Expand Down Expand Up @@ -115,6 +118,30 @@ public Response create(
return Response.created(uri).build();
}

@POST
@Path("/batch")
@Operation(operationId = "createTraces", summary = "Create traces", description = "Create traces", responses = {
@ApiResponse(responseCode = "204", description = "No Content")})
public Response createSpans(
@RequestBody(content = @Content(schema = @Schema(implementation = TraceBatch.class))) @JsonView(Trace.View.Write.class) @NotNull @Valid TraceBatch traces) {

traces.traces()
.stream()
.filter(trace -> trace.id() != null) // Filter out spans with null IDs
.collect(Collectors.groupingBy(Trace::id))
.forEach((id, traceGroup) -> {
if (traceGroup.size() > 1) {
throw new ClientErrorException("Duplicate trace id '%s'".formatted(id), 422);
}
});

service.create(traces)
.contextWrite(ctx -> setRequestContext(ctx, requestContext))
.block();

return Response.noContent().build();
}

@PATCH
@Path("{id}")
@Operation(operationId = "updateTrace", summary = "Update trace by id", description = "Update trace by id", responses = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

@ImplementedBy(FeedbackScoreServiceImpl.class)
public interface FeedbackScoreService {
Flux<FeedbackScore> getScores(EntityType entityType, UUID entityId);

Mono<Void> scoreTrace(UUID traceId, FeedbackScore score);
Mono<Void> scoreSpan(UUID spanId, FeedbackScore score);
Expand Down Expand Up @@ -66,12 +65,6 @@ class FeedbackScoreServiceImpl implements FeedbackScoreService {
record ProjectDto(Project project, List<FeedbackScoreBatchItem> scores) {
}

@Override
public Flux<FeedbackScore> getScores(@NonNull EntityType entityType, @NonNull UUID entityId) {
return asyncTemplate.nonTransaction(connection -> dao.getScores(entityType, List.of(entityId), connection))
.flatMapIterable(entityIdToFeedbackScoresMap -> entityIdToFeedbackScoresMap.get(entityId));
}

@Override
public Mono<Void> scoreTrace(@NonNull UUID traceId, @NonNull FeedbackScore score) {
return lockService.executeWithLock(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import com.comet.opik.domain.filter.FilterQueryBuilder;
import com.comet.opik.domain.filter.FilterStrategy;
import com.comet.opik.utils.JsonUtils;
import com.comet.opik.utils.TemplateUtils;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Preconditions;
import com.google.inject.ImplementedBy;
import com.newrelic.api.agent.Segment;
import io.r2dbc.spi.Connection;
Expand Down Expand Up @@ -38,6 +41,7 @@
import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.startSegment;
import static com.comet.opik.utils.AsyncUtils.makeFluxContextAware;
import static com.comet.opik.utils.AsyncUtils.makeMonoContextAware;
import static com.comet.opik.utils.TemplateUtils.getQueryItemPlaceHolder;

@ImplementedBy(TraceDAOImpl.class)
interface TraceDAO {
Expand All @@ -57,13 +61,49 @@ Mono<Void> partialInsert(UUID projectId, TraceUpdate traceUpdate, UUID traceId,

Mono<List<WorkspaceAndResourceId>> getTraceWorkspace(Set<UUID> traceIds, Connection connection);

Mono<Long> batchInsert(List<Trace> traces, Connection connection);
}

@Slf4j
@Singleton
@RequiredArgsConstructor(onConstructor_ = @Inject)
class TraceDAOImpl implements TraceDAO {

private static final String BATCH_INSERT = """
INSERT INTO traces(
id,
project_id,
workspace_id,
name,
start_time,
end_time,
input,
output,
metadata,
tags,
created_by,
last_updated_by
) VALUES
<items:{item |
(
:id<item.index>,
:project_id<item.index>,
:workspace_id,
:name<item.index>,
parseDateTime64BestEffort(:start_time<item.index>, 9),
if(:end_time<item.index> IS NULL, NULL, parseDateTime64BestEffort(:end_time<item.index>, 9)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:input<item.index>,
:output<item.index>,
:metadata<item.index>,
:tags<item.index>,
:user_name,
:user_name
)
<if(item.hasNext)>,<endif>
}>
;
""";

/**
* This query handles the insertion of a new trace into the database in two cases:
* 1. When the trace does not exist in the database.
Expand Down Expand Up @@ -695,4 +735,61 @@ public Mono<List<WorkspaceAndResourceId>> getTraceWorkspace(
.collectList();
}

@Override
public Mono<Long> batchInsert(@NonNull List<Trace> traces, @NonNull Connection connection) {

Preconditions.checkArgument(!traces.isEmpty(), "traces must not be empty");
Comment on lines +739 to +741
Copy link
Collaborator

@andrescrz andrescrz Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: non-null and empty can be done in one go by removing the non-null annotation and using CollectionUtils.IsNotEmpty.


return Mono.from(insert(traces, connection))
.flatMapMany(Result::getRowsUpdated)
.reduce(0L, Long::sum);

}

    // Renders the BATCH_INSERT template for `traces.size()` rows and binds each trace's
    // fields by index; workspace_id and user_name come from the reactive context and are
    // bound once for the whole statement. Returns a cold publisher of the execution result.
    private Publisher<? extends Result> insert(List<Trace> traces, Connection connection) {

        return makeMonoContextAware((userName, workspaceName, workspaceId) -> {
            // One placeholder group per trace; drives the <items:{...}> loop in the template.
            List<TemplateUtils.QueryItem> queryItems = getQueryItemPlaceHolder(traces.size());

            var template = new ST(BATCH_INSERT)
                    .add("items", queryItems);

            Statement statement = connection.createStatement(template.render());

            // Bind names carry the positional suffix, matching :id<item.index> etc. in the SQL.
            int i = 0;
            for (Trace trace : traces) {

                statement.bind("id" + i, trace.id())
                        .bind("project_id" + i, trace.projectId())
                        .bind("name" + i, trace.name())
                        .bind("start_time" + i, trace.startTime().toString())
                        .bind("input" + i, getOrDefault(trace.input()))
                        .bind("output" + i, getOrDefault(trace.output()))
                        .bind("metadata" + i, getOrDefault(trace.metadata()))
                        .bind("tags" + i, trace.tags() != null ? trace.tags().toArray(String[]::new) : new String[]{});

                // end_time is optional; an explicit NULL bind is required so the
                // if(... IS NULL, ...) branch in the SQL can fire.
                if (trace.endTime() != null) {
                    statement.bind("end_time" + i, trace.endTime().toString());
                } else {
                    statement.bindNull("end_time" + i, String.class);
                }

                i++;
            }

            // Shared binds: same workspace and author for every row in the batch.
            statement
                    .bind("workspace_id", workspaceId)
                    .bind("user_name", userName);

            // New Relic instrumentation around the ClickHouse round-trip.
            Segment segment = startSegment("traces", "Clickhouse", "batch_insert");

            return Mono.from(statement.execute())
                    .doFinally(signalType -> endSegment(segment));
        });
    }

private String getOrDefault(JsonNode value) {
return value != null ? value.toString() : "";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.clickhouse.client.ClickHouseException;
import com.comet.opik.api.Project;
import com.comet.opik.api.Trace;
import com.comet.opik.api.TraceBatch;
import com.comet.opik.api.TraceSearchCriteria;
import com.comet.opik.api.TraceUpdate;
import com.comet.opik.api.error.EntityAlreadyExistsException;
Expand All @@ -13,20 +14,25 @@
import com.comet.opik.infrastructure.redis.LockService;
import com.comet.opik.utils.AsyncUtils;
import com.comet.opik.utils.WorkspaceUtils;
import com.google.common.base.Preconditions;
import com.google.inject.ImplementedBy;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.ws.rs.NotFoundException;
import jakarta.ws.rs.core.Response;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.comet.opik.domain.FeedbackScoreDAO.EntityType;

Expand All @@ -35,6 +41,8 @@ public interface TraceService {

Mono<UUID> create(Trace trace);

Mono<Long> create(TraceBatch batch);

Mono<Void> update(TraceUpdate trace, UUID id);

Mono<Trace> get(UUID id);
Expand Down Expand Up @@ -77,6 +85,49 @@ public Mono<UUID> create(@NonNull Trace trace) {
Mono.defer(() -> insertTrace(trace, project, id))));
}

@com.newrelic.api.agent.Trace(dispatcher = true)
public Mono<Long> create(TraceBatch batch) {

Preconditions.checkArgument(!batch.traces().isEmpty(), "Batch traces cannot be empty");
Comment on lines +89 to +91
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor risk of NPE: better use CollectionUtils.isNotEmpty.


List<String> projectNames = batch.traces()
.stream()
.map(Trace::projectName)
.distinct()
.toList();

Mono<List<Trace>> resolveProjects = Flux.fromIterable(projectNames)
.flatMap(this::resolveProject)
.collectList()
.map(projects -> bindTraceToProjectAndId(batch, projects))
.subscribeOn(Schedulers.boundedElastic());

return resolveProjects
.flatMap(traces -> template.nonTransaction(connection -> dao.batchInsert(traces, connection)));
}

private List<Trace> bindTraceToProjectAndId(TraceBatch batch, List<Project> projects) {
Map<String, Project> projectPerName = projects.stream()
.collect(Collectors.toMap(Project::name, Function.identity()));

return batch.traces()
.stream()
.map(trace -> {
String projectName = WorkspaceUtils.getProjectName(trace.projectName());
Project project = projectPerName.get(projectName);

UUID id = trace.id() == null ? idGenerator.generateId() : trace.id();
IdGenerator.validateVersion(id, TRACE_KEY);

return trace.toBuilder().id(id).projectId(project.id()).build();
})
.toList();
}

    // Normalizes the incoming project name (presumably substituting a default when
    // blank — see WorkspaceUtils.getProjectName), then fetches or creates the project.
    private Mono<Project> resolveProject(String projectName) {
        return getOrCreateProject(WorkspaceUtils.getProjectName(projectName));
    }
Comment on lines +127 to +129
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: kind of pointless method, it can be removed.


private Mono<UUID> insertTrace(Trace newTrace, Project project, UUID id) {
//TODO: refactor to implement proper conflict resolution
return template.nonTransaction(connection -> dao.findById(id, connection))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package com.comet.opik.api.resources.utils;

import lombok.experimental.UtilityClass;

import java.net.URI;
import java.util.UUID;

@UtilityClass
public class TestUtils {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: better use a more specific name for this class. Having said that, this is a nice change.


public static UUID getIdFromLocation(URI location) {
return UUID.fromString(location.getPath().substring(location.getPath().lastIndexOf('/') + 1));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.comet.opik.api.resources.utils.MySQLContainerUtils;
import com.comet.opik.api.resources.utils.RedisContainerUtils;
import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils;
import com.comet.opik.api.resources.utils.TestUtils;
import com.comet.opik.api.resources.utils.WireMockUtils;
import com.comet.opik.domain.FeedbackScoreMapper;
import com.comet.opik.podam.PodamFactoryUtils;
Expand Down Expand Up @@ -195,8 +196,7 @@ private UUID createAndAssert(Dataset dataset, String apiKey, String workspaceNam
assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201);
assertThat(actualResponse.hasEntity()).isFalse();

var id = UUID.fromString(actualResponse.getHeaderString("Location")
.substring(actualResponse.getHeaderString("Location").lastIndexOf('/') + 1));
var id = TestUtils.getIdFromLocation(actualResponse.getLocation());

assertThat(id).isNotNull();
assertThat(id.version()).isEqualTo(7);
Expand Down Expand Up @@ -2532,8 +2532,7 @@ private UUID createTrace(Trace trace, String apiKey, String workspaceName) {
.post(Entity.entity(trace, MediaType.APPLICATION_JSON_TYPE))) {
assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201);

return UUID.fromString(actualResponse.getHeaderString("Location")
.substring(actualResponse.getHeaderString("Location").lastIndexOf('/') + 1));
return TestUtils.getIdFromLocation(actualResponse.getLocation());
}
}

Expand All @@ -2545,8 +2544,7 @@ private UUID createSpan(Span span, String apiKey, String workspaceName) {
.post(Entity.entity(span, MediaType.APPLICATION_JSON_TYPE))) {
assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201);

return UUID.fromString(actualResponse.getHeaderString("Location")
.substring(actualResponse.getHeaderString("Location").lastIndexOf('/') + 1));
return TestUtils.getIdFromLocation(actualResponse.getLocation());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.comet.opik.api.resources.utils.MySQLContainerUtils;
import com.comet.opik.api.resources.utils.RedisContainerUtils;
import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils;
import com.comet.opik.api.resources.utils.TestUtils;
import com.comet.opik.api.resources.utils.WireMockUtils;
import com.comet.opik.domain.FeedbackScoreMapper;
import com.comet.opik.podam.PodamFactoryUtils;
Expand Down Expand Up @@ -1536,8 +1537,7 @@ private UUID createAndAssert(Experiment expectedExperiment, String apiKey, Strin

assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201);

var path = actualResponse.getLocation().getPath();
var actualId = UUID.fromString(path.substring(path.lastIndexOf('/') + 1));
var actualId = TestUtils.getIdFromLocation(actualResponse.getLocation());

assertThat(actualResponse.hasEntity()).isFalse();

Expand Down
Loading