diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java new file mode 100644 index 0000000000..57435ebf26 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/SpanBatch.java @@ -0,0 +1,12 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonView; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Size; + +import java.util.List; + +public record SpanBatch(@NotNull @Size(min = 1, max = 1000) @JsonView( { + Span.View.Write.class}) @Valid List spans){ +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java new file mode 100644 index 0000000000..ac5c164940 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/TraceBatch.java @@ -0,0 +1,12 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonView; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.Size; + +import java.util.List; + +public record TraceBatch(@NotNull @Size(min = 1, max = 1000) @JsonView( { + Trace.View.Write.class}) @Valid List traces){ +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java index 948ad91a41..968bd9d870 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/SpansResource.java @@ -5,6 +5,7 @@ import com.comet.opik.api.FeedbackScore; import com.comet.opik.api.FeedbackScoreBatch; import com.comet.opik.api.Span; +import com.comet.opik.api.SpanBatch; import com.comet.opik.api.SpanSearchCriteria; import com.comet.opik.api.SpanUpdate; 
import com.comet.opik.api.filter.FiltersFactory; @@ -27,6 +28,7 @@ import jakarta.validation.Valid; import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotNull; +import jakarta.ws.rs.ClientErrorException; import jakarta.ws.rs.Consumes; import jakarta.ws.rs.DELETE; import jakarta.ws.rs.DefaultValue; @@ -47,6 +49,7 @@ import lombok.extern.slf4j.Slf4j; import java.util.UUID; +import java.util.stream.Collectors; import static com.comet.opik.api.Span.SpanPage; import static com.comet.opik.api.Span.View; @@ -142,6 +145,30 @@ public Response create( return Response.created(uri).build(); } + @POST + @Path("/batch") + @Operation(operationId = "createSpans", summary = "Create spans", description = "Create spans", responses = { + @ApiResponse(responseCode = "204", description = "No Content")}) + public Response createSpans( + @RequestBody(content = @Content(schema = @Schema(implementation = SpanBatch.class))) @JsonView(Span.View.Write.class) @NotNull @Valid SpanBatch spans) { + + spans.spans() + .stream() + .filter(span -> span.id() != null) // Filter out spans with null IDs + .collect(Collectors.groupingBy(Span::id)) + .forEach((id, spanGroup) -> { + if (spanGroup.size() > 1) { + throw new ClientErrorException("Duplicate span id '%s'".formatted(id), 422); + } + }); + + spanService.create(spans) + .contextWrite(ctx -> setRequestContext(ctx, requestContext)) + .block(); + + return Response.noContent().build(); + } + @PATCH @Path("{id}") @Operation(operationId = "updateSpan", summary = "Update span by id", description = "Update span by id", responses = { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TraceResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TraceResource.java index 5e9bf57d00..31bf42715c 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TraceResource.java +++ 
b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/TraceResource.java @@ -6,6 +6,7 @@ import com.comet.opik.api.FeedbackScoreBatch; import com.comet.opik.api.Trace; import com.comet.opik.api.Trace.TracePage; +import com.comet.opik.api.TraceBatch; import com.comet.opik.api.TraceSearchCriteria; import com.comet.opik.api.TraceUpdate; import com.comet.opik.api.filter.FiltersFactory; @@ -26,6 +27,7 @@ import jakarta.validation.Valid; import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotNull; +import jakarta.ws.rs.ClientErrorException; import jakarta.ws.rs.Consumes; import jakarta.ws.rs.DELETE; import jakarta.ws.rs.DefaultValue; @@ -46,6 +48,7 @@ import lombok.extern.slf4j.Slf4j; import java.util.UUID; +import java.util.stream.Collectors; import static com.comet.opik.utils.AsyncUtils.setRequestContext; import static com.comet.opik.utils.ValidationUtils.validateProjectNameAndProjectId; @@ -142,6 +145,30 @@ public Response create( return Response.created(uri).build(); } + @POST + @Path("/batch") + @Operation(operationId = "createTraces", summary = "Create traces", description = "Create traces", responses = { + @ApiResponse(responseCode = "204", description = "No Content")}) + public Response createSpans( + @RequestBody(content = @Content(schema = @Schema(implementation = TraceBatch.class))) @JsonView(Trace.View.Write.class) @NotNull @Valid TraceBatch traces) { + + traces.traces() + .stream() + .filter(trace -> trace.id() != null) // Filter out spans with null IDs + .collect(Collectors.groupingBy(Trace::id)) + .forEach((id, traceGroup) -> { + if (traceGroup.size() > 1) { + throw new ClientErrorException("Duplicate trace id '%s'".formatted(id), 422); + } + }); + + service.create(traces) + .contextWrite(ctx -> setRequestContext(ctx, requestContext)) + .block(); + + return Response.noContent().build(); + } + @PATCH @Path("{id}") @Operation(operationId = "updateTrace", summary = "Update trace by id", description = "Update trace 
by id", responses = { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java index 9dad98d5c8..3c2569ba61 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/FeedbackScoreService.java @@ -37,7 +37,6 @@ @ImplementedBy(FeedbackScoreServiceImpl.class) public interface FeedbackScoreService { - Flux getScores(EntityType entityType, UUID entityId); Mono scoreTrace(UUID traceId, FeedbackScore score); Mono scoreSpan(UUID spanId, FeedbackScore score); @@ -66,12 +65,6 @@ class FeedbackScoreServiceImpl implements FeedbackScoreService { record ProjectDto(Project project, List scores) { } - @Override - public Flux getScores(@NonNull EntityType entityType, @NonNull UUID entityId) { - return asyncTemplate.nonTransaction(connection -> dao.getScores(entityType, List.of(entityId), connection)) - .flatMapIterable(entityIdToFeedbackScoresMap -> entityIdToFeedbackScoresMap.get(entityId)); - } - @Override public Mono scoreTrace(@NonNull UUID traceId, @NonNull FeedbackScore score) { return lockService.executeWithLock( diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java index b75a170307..0a6a9daa4e 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java @@ -6,6 +6,8 @@ import com.comet.opik.domain.filter.FilterQueryBuilder; import com.comet.opik.domain.filter.FilterStrategy; import com.comet.opik.utils.JsonUtils; +import com.comet.opik.utils.TemplateUtils; +import com.google.common.base.Preconditions; import com.newrelic.api.agent.Segment; import com.newrelic.api.agent.Trace; import io.r2dbc.spi.Connection; @@ -36,15 +38,60 @@ import static 
com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToFlux; import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToMono; import static com.comet.opik.domain.FeedbackScoreDAO.EntityType; +import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.endSegment; import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.startSegment; import static com.comet.opik.utils.AsyncUtils.makeFluxContextAware; import static com.comet.opik.utils.AsyncUtils.makeMonoContextAware; +import static com.comet.opik.utils.TemplateUtils.getQueryItemPlaceHolder; @Singleton @RequiredArgsConstructor(onConstructor_ = @Inject) @Slf4j class SpanDAO { + private static final String BULK_INSERT = """ + INSERT INTO spans( + id, + project_id, + workspace_id, + trace_id, + parent_span_id, + name, + type, + start_time, + end_time, + input, + output, + metadata, + tags, + usage, + created_by, + last_updated_by + ) VALUES + , + :project_id, + :workspace_id, + :trace_id, + :parent_span_id, + :name, + :type, + parseDateTime64BestEffort(:start_time, 9), + if(:end_time IS NULL, NULL, parseDateTime64BestEffort(:end_time, 9)), + :input, + :output, + :metadata, + :tags, + mapFromArrays(:usage_keys, :usage_values), + :created_by, + :last_updated_by + ) + , + }> + ; + """; + /** * This query handles the insertion of a new span into the database in two cases: * 1. When the span does not exist in the database. 
@@ -444,6 +491,78 @@ public Mono insert(@NonNull Span span) { .then(); } + @Trace(dispatcher = true) + public Mono batchInsert(@NonNull List spans) { + + Preconditions.checkArgument(!spans.isEmpty(), "Spans list must not be empty"); + + return Mono.from(connectionFactory.create()) + .flatMapMany(connection -> insert(spans, connection)) + .flatMap(Result::getRowsUpdated) + .reduce(0L, Long::sum); + } + + private Publisher insert(List spans, Connection connection) { + + return makeMonoContextAware((userName, workspaceName, workspaceId) -> { + List queryItems = getQueryItemPlaceHolder(spans.size()); + + var template = new ST(BULK_INSERT) + .add("items", queryItems); + + Statement statement = connection.createStatement(template.render()); + + int i = 0; + for (Span span : spans) { + + statement.bind("id" + i, span.id()) + .bind("project_id" + i, span.projectId()) + .bind("trace_id" + i, span.traceId()) + .bind("name" + i, span.name()) + .bind("type" + i, span.type().toString()) + .bind("start_time" + i, span.startTime().toString()) + .bind("parent_span_id" + i, span.parentSpanId() != null ? span.parentSpanId() : "") + .bind("input" + i, span.input() != null ? span.input().toString() : "") + .bind("output" + i, span.output() != null ? span.output().toString() : "") + .bind("metadata" + i, span.metadata() != null ? span.metadata().toString() : "") + .bind("tags" + i, span.tags() != null ? 
span.tags().toArray(String[]::new) : new String[]{}) + .bind("created_by" + i, userName) + .bind("last_updated_by" + i, userName); + + if (span.endTime() != null) { + statement.bind("end_time" + i, span.endTime().toString()); + } else { + statement.bindNull("end_time" + i, String.class); + } + + if (span.usage() != null) { + Stream.Builder keys = Stream.builder(); + Stream.Builder values = Stream.builder(); + + span.usage().forEach((key, value) -> { + keys.add(key); + values.add(value); + }); + + statement.bind("usage_keys" + i, keys.build().toArray(String[]::new)); + statement.bind("usage_values" + i, values.build().toArray(Integer[]::new)); + } else { + statement.bind("usage_keys" + i, new String[]{}); + statement.bind("usage_values" + i, new Integer[]{}); + } + + i++; + } + + statement.bind("workspace_id", workspaceId); + + Segment segment = startSegment("spans", "Clickhouse", "batch_insert"); + + return Mono.from(statement.execute()) + .doFinally(signalType -> endSegment(segment)); + }); + } + private Publisher insert(Span span, Connection connection) { var template = newInsertTemplate(span); var statement = connection.createStatement(template.render()) @@ -788,5 +907,4 @@ public Mono> getSpanWorkspace(@NonNull Set sp row.get("id", UUID.class)))) .collectList(); } - } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java index a27632794b..27caa0053f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/SpanService.java @@ -3,6 +3,7 @@ import com.clickhouse.client.ClickHouseException; import com.comet.opik.api.Project; import com.comet.opik.api.Span; +import com.comet.opik.api.SpanBatch; import com.comet.opik.api.SpanSearchCriteria; import com.comet.opik.api.SpanUpdate; import com.comet.opik.api.error.EntityAlreadyExistsException; @@ -11,6 +12,7 @@ import 
com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.redis.LockService; import com.comet.opik.utils.WorkspaceUtils; +import com.google.common.base.Preconditions; import com.newrelic.api.agent.Trace; import jakarta.inject.Inject; import jakarta.inject.Singleton; @@ -18,14 +20,18 @@ import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import java.time.Instant; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.UUID; +import java.util.function.Function; +import java.util.stream.Collectors; import static com.comet.opik.utils.AsyncUtils.makeMonoContextAware; @@ -34,10 +40,10 @@ @Slf4j public class SpanService { - public static final String PROJECT_NAME_AND_WORKSPACE_MISMATCH = "Project name and workspace name do not match the existing span"; public static final String PARENT_SPAN_IS_MISMATCH = "parent_span_id does not match the existing span"; public static final String TRACE_ID_MISMATCH = "trace_id does not match the existing span"; public static final String SPAN_KEY = "Span"; + public static final String PROJECT_NAME_MISMATCH = "Project name and workspace name do not match the existing span"; private final @NonNull SpanDAO spanDAO; private final @NonNull ProjectService projectService; @@ -116,7 +122,7 @@ private Mono insertSpan(Span span, Project project, UUID id, Span existing } if (!project.id().equals(existingSpan.projectId())) { - return failWithConflict(PROJECT_NAME_AND_WORKSPACE_MISMATCH); + return failWithConflict(PROJECT_NAME_MISMATCH); } if (!Objects.equals(span.parentSpanId(), existingSpan.parentSpanId())) { @@ -191,7 +197,7 @@ private Mono handleSpanDBError(Throwable ex) { && (ex.getMessage().contains("_CAST(project_id, FixedString(36))") || ex.getMessage() .contains(", CAST(leftPad(workspace_id, 40, '*'), 
'FixedString(19)') ::"))) { - return failWithConflict(PROJECT_NAME_AND_WORKSPACE_MISMATCH); + return failWithConflict(PROJECT_NAME_MISMATCH); } if (ex instanceof ClickHouseException @@ -214,7 +220,7 @@ private Mono handleSpanDBError(Throwable ex) { private Mono updateOrFail(SpanUpdate spanUpdate, UUID id, Span existingSpan, Project project) { if (!project.id().equals(existingSpan.projectId())) { - return failWithConflict(PROJECT_NAME_AND_WORKSPACE_MISMATCH); + return failWithConflict(PROJECT_NAME_MISMATCH); } if (!Objects.equals(existingSpan.parentSpanId(), spanUpdate.parentSpanId())) { @@ -247,4 +253,47 @@ public Mono validateSpanWorkspace(@NonNull String workspaceId, @NonNull return spanDAO.getSpanWorkspace(spanIds) .map(spanWorkspace -> spanWorkspace.stream().allMatch(span -> workspaceId.equals(span.workspaceId()))); } + + @Trace(dispatcher = true) + public Mono create(@NonNull SpanBatch batch) { + + Preconditions.checkArgument(!batch.spans().isEmpty(), "Batch spans must not be empty"); + + List projectNames = batch.spans() + .stream() + .map(Span::projectName) + .distinct() + .toList(); + + Mono> resolveProjects = Flux.fromIterable(projectNames) + .flatMap(this::resolveProject) + .collectList() + .map(projects -> bindSpanToProjectAndId(batch, projects)) + .subscribeOn(Schedulers.boundedElastic()); + + return resolveProjects + .flatMap(spanDAO::batchInsert); + } + + private List bindSpanToProjectAndId(SpanBatch batch, List projects) { + Map projectPerName = projects.stream() + .collect(Collectors.toMap(Project::name, Function.identity())); + + return batch.spans() + .stream() + .map(span -> { + String projectName = WorkspaceUtils.getProjectName(span.projectName()); + Project project = projectPerName.get(projectName); + + UUID id = span.id() == null ? 
idGenerator.generateId() : span.id(); + IdGenerator.validateVersion(id, SPAN_KEY); + + return span.toBuilder().id(id).projectId(project.id()).build(); + }) + .toList(); + } + + private Mono resolveProject(String projectName) { + return getOrCreateProject(WorkspaceUtils.getProjectName(projectName)); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java index 256daa3ad3..5a054f3245 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceDAO.java @@ -6,6 +6,9 @@ import com.comet.opik.domain.filter.FilterQueryBuilder; import com.comet.opik.domain.filter.FilterStrategy; import com.comet.opik.utils.JsonUtils; +import com.comet.opik.utils.TemplateUtils; +import com.fasterxml.jackson.databind.JsonNode; +import com.google.common.base.Preconditions; import com.google.inject.ImplementedBy; import com.newrelic.api.agent.Segment; import io.r2dbc.spi.Connection; @@ -38,6 +41,7 @@ import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.startSegment; import static com.comet.opik.utils.AsyncUtils.makeFluxContextAware; import static com.comet.opik.utils.AsyncUtils.makeMonoContextAware; +import static com.comet.opik.utils.TemplateUtils.getQueryItemPlaceHolder; @ImplementedBy(TraceDAOImpl.class) interface TraceDAO { @@ -57,6 +61,7 @@ Mono partialInsert(UUID projectId, TraceUpdate traceUpdate, UUID traceId, Mono> getTraceWorkspace(Set traceIds, Connection connection); + Mono batchInsert(List traces, Connection connection); } @Slf4j @@ -64,6 +69,41 @@ Mono partialInsert(UUID projectId, TraceUpdate traceUpdate, UUID traceId, @RequiredArgsConstructor(onConstructor_ = @Inject) class TraceDAOImpl implements TraceDAO { + private static final String BATCH_INSERT = """ + INSERT INTO traces( + id, + project_id, + workspace_id, + name, + start_time, + end_time, + input, + output, + 
metadata, + tags, + created_by, + last_updated_by + ) VALUES + , + :project_id, + :workspace_id, + :name, + parseDateTime64BestEffort(:start_time, 9), + if(:end_time IS NULL, NULL, parseDateTime64BestEffort(:end_time, 9)), + :input, + :output, + :metadata, + :tags, + :user_name, + :user_name + ) + , + }> + ; + """; + /** * This query handles the insertion of a new trace into the database in two cases: * 1. When the trace does not exist in the database. @@ -695,4 +735,61 @@ public Mono> getTraceWorkspace( .collectList(); } + @Override + public Mono batchInsert(@NonNull List traces, @NonNull Connection connection) { + + Preconditions.checkArgument(!traces.isEmpty(), "traces must not be empty"); + + return Mono.from(insert(traces, connection)) + .flatMapMany(Result::getRowsUpdated) + .reduce(0L, Long::sum); + + } + + private Publisher insert(List traces, Connection connection) { + + return makeMonoContextAware((userName, workspaceName, workspaceId) -> { + List queryItems = getQueryItemPlaceHolder(traces.size()); + + var template = new ST(BATCH_INSERT) + .add("items", queryItems); + + Statement statement = connection.createStatement(template.render()); + + int i = 0; + for (Trace trace : traces) { + + statement.bind("id" + i, trace.id()) + .bind("project_id" + i, trace.projectId()) + .bind("name" + i, trace.name()) + .bind("start_time" + i, trace.startTime().toString()) + .bind("input" + i, getOrDefault(trace.input())) + .bind("output" + i, getOrDefault(trace.output())) + .bind("metadata" + i, getOrDefault(trace.metadata())) + .bind("tags" + i, trace.tags() != null ? 
trace.tags().toArray(String[]::new) : new String[]{}); + + if (trace.endTime() != null) { + statement.bind("end_time" + i, trace.endTime().toString()); + } else { + statement.bindNull("end_time" + i, String.class); + } + + i++; + } + + statement + .bind("workspace_id", workspaceId) + .bind("user_name", userName); + + Segment segment = startSegment("traces", "Clickhouse", "batch_insert"); + + return Mono.from(statement.execute()) + .doFinally(signalType -> endSegment(segment)); + }); + } + + private String getOrDefault(JsonNode value) { + return value != null ? value.toString() : ""; + } + } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java index 405d2406e2..1b3b70a556 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java @@ -3,6 +3,7 @@ import com.clickhouse.client.ClickHouseException; import com.comet.opik.api.Project; import com.comet.opik.api.Trace; +import com.comet.opik.api.TraceBatch; import com.comet.opik.api.TraceSearchCriteria; import com.comet.opik.api.TraceUpdate; import com.comet.opik.api.error.EntityAlreadyExistsException; @@ -13,6 +14,7 @@ import com.comet.opik.infrastructure.redis.LockService; import com.comet.opik.utils.AsyncUtils; import com.comet.opik.utils.WorkspaceUtils; +import com.google.common.base.Preconditions; import com.google.inject.ImplementedBy; import jakarta.inject.Inject; import jakarta.inject.Singleton; @@ -21,13 +23,17 @@ import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import java.time.Instant; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.UUID; +import java.util.function.Function; +import java.util.stream.Collectors; 
import static com.comet.opik.domain.FeedbackScoreDAO.EntityType; @@ -36,6 +42,8 @@ public interface TraceService { Mono create(Trace trace); + Mono create(TraceBatch batch); + Mono update(TraceUpdate trace, UUID id); Mono get(UUID id); @@ -79,6 +87,49 @@ public Mono create(@NonNull Trace trace) { Mono.defer(() -> insertTrace(trace, project, id)))); } + @com.newrelic.api.agent.Trace(dispatcher = true) + public Mono create(TraceBatch batch) { + + Preconditions.checkArgument(!batch.traces().isEmpty(), "Batch traces cannot be empty"); + + List projectNames = batch.traces() + .stream() + .map(Trace::projectName) + .distinct() + .toList(); + + Mono> resolveProjects = Flux.fromIterable(projectNames) + .flatMap(this::resolveProject) + .collectList() + .map(projects -> bindTraceToProjectAndId(batch, projects)) + .subscribeOn(Schedulers.boundedElastic()); + + return resolveProjects + .flatMap(traces -> template.nonTransaction(connection -> dao.batchInsert(traces, connection))); + } + + private List bindTraceToProjectAndId(TraceBatch batch, List projects) { + Map projectPerName = projects.stream() + .collect(Collectors.toMap(Project::name, Function.identity())); + + return batch.traces() + .stream() + .map(trace -> { + String projectName = WorkspaceUtils.getProjectName(trace.projectName()); + Project project = projectPerName.get(projectName); + + UUID id = trace.id() == null ? 
idGenerator.generateId() : trace.id(); + IdGenerator.validateVersion(id, TRACE_KEY); + + return trace.toBuilder().id(id).projectId(project.id()).build(); + }) + .toList(); + } + + private Mono resolveProject(String projectName) { + return getOrCreateProject(WorkspaceUtils.getProjectName(projectName)); + } + private Mono insertTrace(Trace newTrace, Project project, UUID id) { //TODO: refactor to implement proper conflict resolution return template.nonTransaction(connection -> dao.findById(id, connection)) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/instrumentation/InstrumentAsyncUtils.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/instrumentation/InstrumentAsyncUtils.java index cc6f81b4df..0ba6e20572 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/instrumentation/InstrumentAsyncUtils.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/instrumentation/InstrumentAsyncUtils.java @@ -3,10 +3,12 @@ import com.newrelic.api.agent.DatastoreParameters; import com.newrelic.api.agent.NewRelic; import com.newrelic.api.agent.Segment; +import lombok.experimental.UtilityClass; import lombok.extern.slf4j.Slf4j; import reactor.core.scheduler.Schedulers; @Slf4j +@UtilityClass public class InstrumentAsyncUtils { public static Segment startSegment(String segmentName, String product, String operationName) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java index ded698ca27..5ef02eac86 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/redis/RedissonLockService.java @@ -21,14 +21,9 @@ class RedissonLockService implements LockService { private final @NonNull DistributedLockConfig distributedLockConfig; @Override - public 
Mono executeWithLock(Lock lock, Mono action) { + public Mono executeWithLock(@NonNull Lock lock, @NonNull Mono action) { - RPermitExpirableSemaphoreReactive semaphore = redisClient.getPermitExpirableSemaphore( - CommonOptions - .name(lock.key()) - .timeout(Duration.ofMillis(distributedLockConfig.getLockTimeoutMS())) - .retryInterval(Duration.ofMillis(10)) - .retryAttempts(distributedLockConfig.getLockTimeoutMS() / 10)); + RPermitExpirableSemaphoreReactive semaphore = getSemaphore(lock, distributedLockConfig.getLockTimeoutMS()); log.debug("Trying to lock with {}", lock); @@ -43,6 +38,15 @@ public Mono executeWithLock(Lock lock, Mono action) { })); } + private RPermitExpirableSemaphoreReactive getSemaphore(Lock lock, int lockTimeoutMS) { + return redisClient.getPermitExpirableSemaphore( + CommonOptions + .name(lock.key()) + .timeout(Duration.ofMillis(lockTimeoutMS)) + .retryInterval(Duration.ofMillis(10)) + .retryAttempts(lockTimeoutMS / 10)); + } + private Mono runAction(Lock lock, Mono action, String locked) { if (locked != null) { log.debug("Lock {} acquired", lock); @@ -53,13 +57,8 @@ private Mono runAction(Lock lock, Mono action, String locked) { } @Override - public Flux executeWithLock(Lock lock, Flux stream) { - RPermitExpirableSemaphoreReactive semaphore = redisClient.getPermitExpirableSemaphore( - CommonOptions - .name(lock.key()) - .timeout(Duration.ofMillis(distributedLockConfig.getLockTimeoutMS())) - .retryInterval(Duration.ofMillis(10)) - .retryAttempts(distributedLockConfig.getLockTimeoutMS() / 10)); + public Flux executeWithLock(@NonNull Lock lock, @NonNull Flux stream) { + RPermitExpirableSemaphoreReactive semaphore = getSemaphore(lock, distributedLockConfig.getLockTimeoutMS()); return semaphore .trySetPermits(1) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestUtils.java new file mode 100644 index 0000000000..99d5fd1d33 --- 
// File: apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/TestUtils.java
// Shared REST-test helper: extracts the id of a newly created entity from the
// "Location" response header (the last path segment of the URI).
// Written as an explicit utility class (final + private constructor) — the
// hand-rolled equivalent of lombok's @UtilityClass for this single method.
public final class TestUtils {

    private TestUtils() {
        // static-only holder; never instantiated
    }

    public static UUID getIdFromLocation(URI location) {
        String path = location.getPath();
        int lastSlash = path.lastIndexOf('/');
        return UUID.fromString(path.substring(lastSlash + 1));
    }
}
{ assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - return UUID.fromString(actualResponse.getHeaderString("Location") - .substring(actualResponse.getHeaderString("Location").lastIndexOf('/') + 1)); + return TestUtils.getIdFromLocation(actualResponse.getLocation()); } } @@ -2545,8 +2544,7 @@ private UUID createSpan(Span span, String apiKey, String workspaceName) { .post(Entity.entity(span, MediaType.APPLICATION_JSON_TYPE))) { assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - return UUID.fromString(actualResponse.getHeaderString("Location") - .substring(actualResponse.getHeaderString("Location").lastIndexOf('/') + 1)); + return TestUtils.getIdFromLocation(actualResponse.getLocation()); } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java index 3880b5eb96..2beeb2df2b 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ExperimentsResourceTest.java @@ -18,6 +18,7 @@ import com.comet.opik.api.resources.utils.MySQLContainerUtils; import com.comet.opik.api.resources.utils.RedisContainerUtils; import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestUtils; import com.comet.opik.api.resources.utils.WireMockUtils; import com.comet.opik.domain.FeedbackScoreMapper; import com.comet.opik.podam.PodamFactoryUtils; @@ -1536,8 +1537,7 @@ private UUID createAndAssert(Experiment expectedExperiment, String apiKey, Strin assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - var path = actualResponse.getLocation().getPath(); - var actualId = UUID.fromString(path.substring(path.lastIndexOf('/') + 1)); + var actualId = 
TestUtils.getIdFromLocation(actualResponse.getLocation()); assertThat(actualResponse.hasEntity()).isFalse(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResourceTest.java index 27f543ce41..5390a9f49c 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/FeedbackDefinitionResourceTest.java @@ -8,6 +8,7 @@ import com.comet.opik.api.resources.utils.MySQLContainerUtils; import com.comet.opik.api.resources.utils.RedisContainerUtils; import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestUtils; import com.comet.opik.api.resources.utils.WireMockUtils; import com.comet.opik.podam.PodamFactoryUtils; import com.fasterxml.uuid.Generators; @@ -135,8 +136,7 @@ private UUID create(final FeedbackDefinition feedback, String apiKey, String assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - return UUID.fromString(actualResponse.getLocation().getPath() - .substring(actualResponse.getLocation().getPath().lastIndexOf('/') + 1)); + return TestUtils.getIdFromLocation(actualResponse.getLocation()); } } @@ -848,9 +848,7 @@ void create() { assertThat(actualResponse.hasEntity()).isFalse(); assertThat(actualResponse.getHeaderString("Location")).matches(Pattern.compile(URL_PATTERN)); - id = UUID.fromString(actualResponse.getLocation().getPath() - .substring(actualResponse.getLocation().getPath().lastIndexOf('/') + 1)); - + id = TestUtils.getIdFromLocation(actualResponse.getLocation()); } var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java 
b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java index 117fabcd7d..ea8b879d8a 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ProjectsResourceTest.java @@ -9,6 +9,7 @@ import com.comet.opik.api.resources.utils.MySQLContainerUtils; import com.comet.opik.api.resources.utils.RedisContainerUtils; import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestUtils; import com.comet.opik.api.resources.utils.WireMockUtils; import com.comet.opik.podam.PodamFactoryUtils; import com.github.tomakehurst.wiremock.client.WireMock; @@ -134,8 +135,7 @@ private UUID createProject(Project project, String apiKey, String workspaceName) assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - return UUID.fromString(actualResponse.getLocation().getPath() - .substring(actualResponse.getLocation().getPath().lastIndexOf('/') + 1)); + return TestUtils.getIdFromLocation(actualResponse.getLocation()); } } @@ -820,8 +820,7 @@ void create() { assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); assertThat(actualResponse.hasEntity()).isFalse(); - id = UUID.fromString(actualResponse.getHeaderString("Location") - .substring(actualResponse.getHeaderString("Location").lastIndexOf('/') + 1)); + id = TestUtils.getIdFromLocation(actualResponse.getLocation()); } assertProject(project.toBuilder().id(id) @@ -848,8 +847,7 @@ void create__whenWorkspaceNameIsSpecified__thenAcceptTheRequest() { assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); assertThat(actualResponse.hasEntity()).isFalse(); - id = UUID.fromString(actualResponse.getHeaderString("Location") - .substring(actualResponse.getHeaderString("Location").lastIndexOf('/') + 1)); + id = TestUtils.getIdFromLocation(actualResponse.getLocation()); } @@ 
-932,8 +930,7 @@ void create__whenProjectsHaveSameNameButDifferentWorkspace__thenAcceptTheRequest assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); assertThat(actualResponse.hasEntity()).isFalse(); - id = UUID.fromString(actualResponse.getHeaderString("Location") - .substring(actualResponse.getHeaderString("Location").lastIndexOf('/') + 1)); + id = TestUtils.getIdFromLocation(actualResponse.getLocation()); } var project2 = project1.toBuilder() @@ -950,8 +947,7 @@ void create__whenProjectsHaveSameNameButDifferentWorkspace__thenAcceptTheRequest assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); assertThat(actualResponse.hasEntity()).isFalse(); - id2 = UUID.fromString(actualResponse.getHeaderString("Location") - .substring(actualResponse.getHeaderString("Location").lastIndexOf('/') + 1)); + id2 = TestUtils.getIdFromLocation(actualResponse.getLocation()); } assertProject(project1.toBuilder().id(id).build()); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansResourceTest.java index 39a92f3e5b..028c1c3c75 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/SpansResourceTest.java @@ -7,6 +7,7 @@ import com.comet.opik.api.Project; import com.comet.opik.api.ScoreSource; import com.comet.opik.api.Span; +import com.comet.opik.api.SpanBatch; import com.comet.opik.api.SpanUpdate; import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.api.filter.Field; @@ -21,6 +22,7 @@ import com.comet.opik.api.resources.utils.MySQLContainerUtils; import com.comet.opik.api.resources.utils.RedisContainerUtils; import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestUtils; import 
com.comet.opik.api.resources.utils.WireMockUtils; import com.comet.opik.domain.SpanMapper; import com.comet.opik.domain.SpanType; @@ -160,7 +162,7 @@ private static void mockTargetWorkspace(String apiKey, String workspaceName, Str AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER); } - private UUID getProjectId(ClientSupport client, String projectName, String workspaceName, String apiKey) { + private UUID getProjectId(String projectName, String workspaceName, String apiKey) { return client.target("%s/v1/private/projects".formatted(baseURI)) .queryParam("name", projectName) .request() @@ -175,6 +177,18 @@ private UUID getProjectId(ClientSupport client, String projectName, String works .id(); } + private UUID createProject(String projectName, String workspaceName, String apiKey) { + try (Response response = client.target("%s/v1/private/projects".formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(Project.builder().name(projectName).build()))) { + + assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(201); + return TestUtils.getIdFromLocation(response.getLocation()); + } + } + @Nested @DisplayName("Api Key Authentication:") @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -2958,122 +2972,124 @@ void getByProjectName__whenFilterInvalidValueOrKeyForFieldType__thenReturn400(Fi assertThat(actualError).isEqualTo(expectedError); } - private void getAndAssertPage( - String workspaceName, - String projectName, - List filters, - List spans, - List expectedSpans, - List unexpectedSpans, String apiKey) { - int page = 1; - int size = spans.size() + expectedSpans.size() + unexpectedSpans.size(); - getAndAssertPage( - workspaceName, - projectName, - null, - null, - null, - filters, - page, - size, - expectedSpans, - expectedSpans.size(), - unexpectedSpans, apiKey); + private List updateFeedbackScore(List feedbackScores, int index, double 
val) { + feedbackScores.set(index, feedbackScores.get(index).toBuilder() + .value(BigDecimal.valueOf(val)) + .build()); + return feedbackScores; } - private void getAndAssertPage( - String workspaceName, - String projectName, - UUID projectId, - UUID traceId, - SpanType type, - List filters, - int page, - int size, - List expectedSpans, - int expectedTotal, - List unexpectedSpans, String apiKey) { - try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) - .queryParam("page", page) - .queryParam("size", size) - .queryParam("project_name", projectName) - .queryParam("project_id", projectId) - .queryParam("trace_id", traceId) - .queryParam("type", type) - .queryParam("filters", toURLEncodedQueryParam(filters)) - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .get()) { - var actualPage = actualResponse.readEntity(Span.SpanPage.class); - var actualSpans = actualPage.content(); + private List updateFeedbackScore( + List destination, List source, int index) { + destination.set(index, source.get(index).toBuilder().build()); + return destination; + } + } - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + private void getAndAssertPage( + String workspaceName, + String projectName, + List filters, + List spans, + List expectedSpans, + List unexpectedSpans, String apiKey) { + int page = 1; + int size = spans.size() + expectedSpans.size() + unexpectedSpans.size(); + getAndAssertPage( + workspaceName, + projectName, + null, + null, + null, + filters, + page, + size, + expectedSpans, + expectedSpans.size(), + unexpectedSpans, apiKey); + } - assertThat(actualPage.page()).isEqualTo(page); - assertThat(actualPage.size()).isEqualTo(expectedSpans.size()); - assertThat(actualPage.total()).isEqualTo(expectedTotal); + private void getAndAssertPage( + String workspaceName, + String projectName, + UUID projectId, + UUID traceId, + SpanType type, + List filters, + int page, + int size, + List 
expectedSpans, + int expectedTotal, + List unexpectedSpans, String apiKey) { + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .queryParam("page", page) + .queryParam("size", size) + .queryParam("project_name", projectName) + .queryParam("project_id", projectId) + .queryParam("trace_id", traceId) + .queryParam("type", type) + .queryParam("filters", toURLEncodedQueryParam(filters)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + var actualPage = actualResponse.readEntity(Span.SpanPage.class); + var actualSpans = actualPage.content(); - assertThat(actualSpans.size()).isEqualTo(expectedSpans.size()); - assertThat(actualSpans) - .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS) - .containsExactlyElementsOf(expectedSpans); - assertIgnoredFields(actualSpans, expectedSpans); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + + assertThat(actualPage.page()).isEqualTo(page); + assertThat(actualPage.size()).isEqualTo(expectedSpans.size()); + assertThat(actualPage.total()).isEqualTo(expectedTotal); + assertThat(actualSpans.size()).isEqualTo(expectedSpans.size()); + assertThat(actualSpans) + .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS) + .containsExactlyElementsOf(expectedSpans); + assertIgnoredFields(actualSpans, expectedSpans); + + if (!unexpectedSpans.isEmpty()) { assertThat(actualSpans) .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS) .doesNotContainAnyElementsOf(unexpectedSpans); } } + } - private String toURLEncodedQueryParam(List filters) { - return CollectionUtils.isEmpty(filters) - ? null - : URLEncoder.encode(JsonUtils.writeValueAsString(filters), StandardCharsets.UTF_8); - } + private String toURLEncodedQueryParam(List filters) { + return CollectionUtils.isEmpty(filters) + ? 
null + : URLEncoder.encode(JsonUtils.writeValueAsString(filters), StandardCharsets.UTF_8); + } - private void assertIgnoredFields(List actualSpans, List expectedSpans) { - for (int i = 0; i < actualSpans.size(); i++) { - var actualSpan = actualSpans.get(i); - var expectedSpan = expectedSpans.get(i); - var expectedFeedbackScores = expectedSpan.feedbackScores() == null - ? null - : expectedSpan.feedbackScores().reversed(); - assertThat(actualSpan.projectId()).isNotNull(); - assertThat(actualSpan.projectName()).isNull(); - assertThat(actualSpan.createdAt()).isAfter(expectedSpan.createdAt()); - assertThat(actualSpan.lastUpdatedAt()).isAfter(expectedSpan.lastUpdatedAt()); - assertThat(actualSpan.feedbackScores()) - .usingRecursiveComparison( - RecursiveComparisonConfiguration.builder() - .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) - .withIgnoredFields(IGNORED_FIELDS_SCORES) - .build()) - .isEqualTo(expectedFeedbackScores); - - if (actualSpan.feedbackScores() != null) { - actualSpan.feedbackScores().forEach(feedbackScore -> { - assertThat(feedbackScore.createdAt()).isAfter(expectedSpan.createdAt()); - assertThat(feedbackScore.lastUpdatedAt()).isAfter(expectedSpan.lastUpdatedAt()); - assertThat(feedbackScore.createdBy()).isEqualTo(USER); - assertThat(feedbackScore.lastUpdatedBy()).isEqualTo(USER); - }); - } + private void assertIgnoredFields(List actualSpans, List expectedSpans) { + for (int i = 0; i < actualSpans.size(); i++) { + var actualSpan = actualSpans.get(i); + var expectedSpan = expectedSpans.get(i); + var expectedFeedbackScores = expectedSpan.feedbackScores() == null + ? 
null + : expectedSpan.feedbackScores().reversed(); + assertThat(actualSpan.projectId()).isNotNull(); + assertThat(actualSpan.projectName()).isNull(); + assertThat(actualSpan.createdAt()).isAfter(expectedSpan.createdAt()); + assertThat(actualSpan.lastUpdatedAt()).isAfter(expectedSpan.lastUpdatedAt()); + assertThat(actualSpan.feedbackScores()) + .usingRecursiveComparison( + RecursiveComparisonConfiguration.builder() + .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .withIgnoredFields(IGNORED_FIELDS_SCORES) + .build()) + .isEqualTo(expectedFeedbackScores); + + if (actualSpan.feedbackScores() != null) { + actualSpan.feedbackScores().forEach(feedbackScore -> { + assertThat(feedbackScore.createdAt()).isAfter(expectedSpan.createdAt()); + assertThat(feedbackScore.lastUpdatedAt()).isAfter(expectedSpan.lastUpdatedAt()); + assertThat(feedbackScore.createdBy()).isEqualTo(USER); + assertThat(feedbackScore.lastUpdatedBy()).isEqualTo(USER); + }); } } - - private List updateFeedbackScore(List feedbackScores, int index, double val) { - feedbackScores.set(index, feedbackScores.get(index).toBuilder() - .value(BigDecimal.valueOf(val)) - .build()); - return feedbackScores; - } - - private List updateFeedbackScore( - List destination, List source, int index) { - destination.set(index, source.get(index).toBuilder().build()); - return destination; - } } private UUID createAndAssert(Span expectedSpan, String apiKey, String workspaceName) { @@ -3091,7 +3107,7 @@ private UUID createAndAssert(Span expectedSpan, String apiKey, String workspaceN if (expectedSpan.id() != null) { expectedSpanId = expectedSpan.id(); } else { - expectedSpanId = UUID.fromString(actualHeaderString.substring(actualHeaderString.lastIndexOf('/') + 1)); + expectedSpanId = TestUtils.getIdFromLocation(actualResponse.getLocation()); } assertThat(actualHeaderString).isEqualTo(URL_TEMPLATE.formatted(baseURI) @@ -3189,6 +3205,130 @@ void createWhenTryingToCreateSpanTwice() { } } + @Nested + 
@DisplayName("Batch:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class BatchInsert { + + @Test + void batch__whenCreateSpans__thenReturnNoContent() { + + String projectName = UUID.randomUUID().toString(); + UUID projectId = createProject(projectName, TEST_WORKSPACE, API_KEY); + + var expectedSpans = IntStream.range(0, 1000) + .mapToObj(i -> podamFactory.manufacturePojo(Span.class).toBuilder() + .projectId(projectId) + .projectName(projectName) + .parentSpanId(null) + .feedbackScores(null) + .build()) + .toList(); + + batchCreateAndAssert(expectedSpans, API_KEY, TEST_WORKSPACE); + + getAndAssertPage(TEST_WORKSPACE, projectName, List.of(), List.of(), expectedSpans.reversed(), List.of(), + API_KEY); + } + + @Test + void batch__whenSendingMultipleSpansWithSameId__thenReturn422() { + var expectedSpans = List.of(podamFactory.manufacturePojo(Span.class).toBuilder() + .projectId(null) + .parentSpanId(null) + .feedbackScores(null) + .build()); + + var expectedSpan = expectedSpans.getFirst().toBuilder() + .tags(Set.of()) + .endTime(Instant.now()) + .output(JsonUtils.getJsonNodeFromString("{ \"output\": \"data\"}")) + .build(); + + List expectedSpans1 = List.of(expectedSpans.getFirst(), expectedSpan); + + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path("batch") + .request() + .header(HttpHeaders.AUTHORIZATION, API_KEY) + .header(WORKSPACE_HEADER, TEST_WORKSPACE) + .post(Entity.json(new SpanBatch(expectedSpans1)))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422); + assertThat(actualResponse.hasEntity()).isTrue(); + + var errorMessage = actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class); + assertThat(errorMessage.getMessage()).isEqualTo("Duplicate span id '%s'".formatted(expectedSpan.id())); + } + } + + @ParameterizedTest + @MethodSource + void batch__whenBatchIsInvalid__thenReturn422(List spans, String errorMessage) { + + try (var actualResponse = 
client.target(URL_TEMPLATE.formatted(baseURI)) + .path("batch") + .request() + .header(HttpHeaders.AUTHORIZATION, API_KEY) + .header(WORKSPACE_HEADER, TEST_WORKSPACE) + .post(Entity.json(new SpanBatch(spans)))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422); + assertThat(actualResponse.hasEntity()).isTrue(); + + var responseBody = actualResponse.readEntity(ErrorMessage.class); + assertThat(responseBody.errors()).contains(errorMessage); + } + } + + Stream batch__whenBatchIsInvalid__thenReturn422() { + return Stream.of( + Arguments.of(List.of(), "spans size must be between 1 and 1000"), + Arguments.of(IntStream.range(0, 1001) + .mapToObj(i -> podamFactory.manufacturePojo(Span.class).toBuilder() + .projectId(null) + .parentSpanId(null) + .feedbackScores(null) + .build()) + .toList(), "spans size must be between 1 and 1000")); + } + + @Test + void batch__whenSendingMultipleSpansWithNoId__thenReturnNoContent() { + var newSpan = podamFactory.manufacturePojo(Span.class).toBuilder() + .projectId(null) + .id(null) + .parentSpanId(null) + .feedbackScores(null) + .build(); + + var expectedSpan = newSpan.toBuilder() + .tags(Set.of()) + .endTime(Instant.now()) + .output(JsonUtils.getJsonNodeFromString("{ \"output\": \"data\"}")) + .build(); + + List expectedSpans = List.of(newSpan, expectedSpan); + + batchCreateAndAssert(expectedSpans, API_KEY, TEST_WORKSPACE); + } + + } + + private void batchCreateAndAssert(List expectedSpans, String apiKey, String workspaceName) { + + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path("batch") + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(new SpanBatch(expectedSpans)))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204); + assertThat(actualResponse.hasEntity()).isFalse(); + } + } + private Span getAndAssert(Span expectedSpan, String apiKey, String workspaceName) { try (var 
actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .path(expectedSpan.id().toString()) @@ -3389,7 +3529,7 @@ void when__spanDoesNotExist__thenReturnCreateIt() { var actualResponse = getById(id, TEST_WORKSPACE, API_KEY); - var projectId = getProjectId(client, spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); + var projectId = getProjectId(spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); var actualEntity = actualResponse.readEntity(Span.class); assertThat(actualEntity.id()).isEqualTo(id); @@ -3433,7 +3573,7 @@ void when__spanUpdateAndInsertAreProcessedOutOfOther__thenReturnSpan() { var actualResponse = getById(id, TEST_WORKSPACE, API_KEY); - var projectId = getProjectId(client, spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); + var projectId = getProjectId(spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); var actualEntity = actualResponse.readEntity(Span.class); assertThat(actualEntity.id()).isEqualTo(id); @@ -3641,7 +3781,7 @@ void when__multipleSpanUpdateAndInsertAreProcessedOutOfOtherAndConcurrent__thenR var actualEntity = actualResponse.readEntity(Span.class); assertThat(actualEntity.id()).isEqualTo(id); - var projectId = getProjectId(client, projectName, TEST_WORKSPACE, API_KEY); + var projectId = getProjectId(projectName, TEST_WORKSPACE, API_KEY); assertThat(actualEntity.projectId()).isEqualTo(projectId); assertThat(actualEntity.traceId()).isEqualTo(spanUpdate1.traceId()); @@ -3680,7 +3820,7 @@ void update__whenTagsIsEmpty__thenAcceptUpdate() { runPatchAndAssertStatus(expectedSpan.id(), spanUpdate, API_KEY, TEST_WORKSPACE); - UUID projectId = getProjectId(client, spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); Span updatedSpan = expectedSpan.toBuilder() .tags(spanUpdate.tags()) @@ -3713,7 +3853,7 @@ void update__whenMetadataIsEmpty__thenAcceptUpdate() { runPatchAndAssertStatus(expectedSpan.id(), spanUpdate, API_KEY, TEST_WORKSPACE); - UUID projectId = 
getProjectId(client, spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); Span updatedSpan = expectedSpan.toBuilder() .metadata(metadata) @@ -3746,7 +3886,7 @@ void update__whenInputIsEmpty__thenAcceptUpdate() { runPatchAndAssertStatus(expectedSpan.id(), spanUpdate, API_KEY, TEST_WORKSPACE); - UUID projectId = getProjectId(client, spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); Span updatedSpan = expectedSpan.toBuilder() .input(input) @@ -3778,7 +3918,7 @@ void update__whenOutputIsEmpty__thenAcceptUpdate() { runPatchAndAssertStatus(expectedSpan.id(), spanUpdate, API_KEY, TEST_WORKSPACE); - UUID projectId = getProjectId(client, spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(spanUpdate.projectName(), TEST_WORKSPACE, API_KEY); Span updatedSpan = expectedSpan.toBuilder() .output(output) @@ -3800,7 +3940,7 @@ void update__whenUpdatingUsingProjectId__thenAcceptUpdate() { createAndAssert(expectedSpan, API_KEY, TEST_WORKSPACE); - var projectId = getProjectId(client, expectedSpan.projectName(), TEST_WORKSPACE, API_KEY); + var projectId = getProjectId(expectedSpan.projectName(), TEST_WORKSPACE, API_KEY); var spanUpdate = podamFactory.manufacturePojo(SpanUpdate.class).toBuilder() .traceId(expectedSpan.traceId()) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java index 74687bfee5..aa5e155592 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java @@ -7,6 +7,7 @@ import com.comet.opik.api.Project; import com.comet.opik.api.ScoreSource; import com.comet.opik.api.Trace; +import 
com.comet.opik.api.TraceBatch; import com.comet.opik.api.TraceUpdate; import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.api.filter.Filter; @@ -21,6 +22,7 @@ import com.comet.opik.api.resources.utils.MySQLContainerUtils; import com.comet.opik.api.resources.utils.RedisContainerUtils; import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestUtils; import com.comet.opik.api.resources.utils.WireMockUtils; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.podam.PodamFactoryUtils; @@ -162,7 +164,7 @@ void tearDownAll() { wireMock.server().stop(); } - private UUID getProjectId(ClientSupport client, String projectName, String workspaceName, String apiKey) { + private UUID getProjectId(String projectName, String workspaceName, String apiKey) { return client.target("%s/v1/private/projects".formatted(baseURI)) .queryParam("name", projectName) .request() @@ -177,6 +179,19 @@ private UUID getProjectId(ClientSupport client, String projectName, String works .id(); } + private UUID createProject(String projectName, String workspaceName, String apiKey) { + try (Response response = client.target("%s/v1/private/projects".formatted(baseURI)) + .queryParam("name", projectName) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.json(Project.builder().name(projectName).build()))) { + + assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(201); + return TestUtils.getIdFromLocation(response.getLocation()); + } + } + @Nested @DisplayName("Api Key Authentication:") @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -837,7 +852,7 @@ void getByProjectName__whenProjectIdIsNotEmpty__thenReturnTracesByProjectId() { .feedbackScores(null) .build(), apiKey, workspaceName); - UUID projectId = getProjectId(client, projectName, workspaceName, apiKey); + UUID projectId = getProjectId(projectName, workspaceName, apiKey); 
var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) .queryParam("workspace_name", workspaceName) @@ -2635,93 +2650,96 @@ void getByProjectName__whenFilterInvalidValueOrKeyForFieldType__thenReturn400(Fi var actualError = actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class); assertThat(actualError).isEqualTo(expectedError); } + } - private void getAndAssertPage(String workspaceName, String projectName, List filters, - List traces, - List expectedTraces, List unexpectedTraces, String apiKey) { - int page = 1; - int size = traces.size() + expectedTraces.size() + unexpectedTraces.size(); - getAndAssertPage(page, size, projectName, filters, expectedTraces, unexpectedTraces, - workspaceName, apiKey); - } + private void getAndAssertPage(String workspaceName, String projectName, List filters, + List traces, + List expectedTraces, List unexpectedTraces, String apiKey) { + int page = 1; + int size = traces.size() + expectedTraces.size() + unexpectedTraces.size(); + getAndAssertPage(page, size, projectName, filters, expectedTraces, unexpectedTraces, + workspaceName, apiKey); + } - private void getAndAssertPage(int page, int size, String projectName, List filters, - List expectedTraces, List unexpectedTraces, String workspaceName, String apiKey) { - var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) - .queryParam("page", page) - .queryParam("size", size) - .queryParam("project_name", projectName) - .queryParam("filters", toURLEncodedQueryParam(filters)) - .request() - .header(HttpHeaders.AUTHORIZATION, apiKey) - .header(WORKSPACE_HEADER, workspaceName) - .get(); + private void getAndAssertPage(int page, int size, String projectName, List filters, + List expectedTraces, List unexpectedTraces, String workspaceName, String apiKey) { + var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .queryParam("page", page) + .queryParam("size", size) + .queryParam("project_name", projectName) + .queryParam("filters", 
toURLEncodedQueryParam(filters)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get(); - assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); - var actualPage = actualResponse.readEntity(Trace.TracePage.class); - var actualTraces = actualPage.content(); + var actualPage = actualResponse.readEntity(Trace.TracePage.class); + var actualTraces = actualPage.content(); - assertThat(actualPage.page()).isEqualTo(page); - assertThat(actualPage.size()).isEqualTo(expectedTraces.size()); - assertThat(actualPage.total()).isEqualTo(expectedTraces.size()); - assertThat(actualTraces) - .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_LIST) - .containsExactlyElementsOf(expectedTraces); - assertIgnoredFields(actualTraces, expectedTraces); + assertThat(actualPage.page()).isEqualTo(page); + assertThat(actualPage.size()).isEqualTo(expectedTraces.size()); + assertThat(actualPage.total()).isEqualTo(expectedTraces.size()); + assertThat(actualTraces) + .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_LIST) + .containsExactlyElementsOf(expectedTraces); + assertIgnoredFields(actualTraces, expectedTraces); + + if (!unexpectedTraces.isEmpty()) { assertThat(actualTraces) .usingRecursiveFieldByFieldElementComparatorIgnoringFields(IGNORED_FIELDS_LIST) .doesNotContainAnyElementsOf(unexpectedTraces); } + } - private String toURLEncodedQueryParam(List filters) { - return URLEncoder.encode(JsonUtils.writeValueAsString(filters), StandardCharsets.UTF_8); - } + private String toURLEncodedQueryParam(List filters) { + return URLEncoder.encode(JsonUtils.writeValueAsString(filters), StandardCharsets.UTF_8); + } - private void assertIgnoredFields(List actualTraces, List expectedTraces) { - for (int i = 0; i < actualTraces.size(); i++) { - var actualTrace = actualTraces.get(i); - var expectedTrace = 
expectedTraces.get(i); - var expectedFeedbackScores = expectedTrace.feedbackScores() == null - ? null - : expectedTrace.feedbackScores().reversed(); - assertThat(actualTrace.projectId()).isNotNull(); - assertThat(actualTrace.projectName()).isNull(); - assertThat(actualTrace.createdAt()).isAfter(expectedTrace.createdAt()); - assertThat(actualTrace.lastUpdatedAt()).isAfter(expectedTrace.lastUpdatedAt()); - assertThat(actualTrace.lastUpdatedBy()).isEqualTo(USER); - assertThat(actualTrace.lastUpdatedBy()).isEqualTo(USER); - assertThat(actualTrace.feedbackScores()) - .usingRecursiveComparison( - RecursiveComparisonConfiguration.builder() - .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) - .withIgnoredFields(IGNORED_FIELDS) - .build()) - .isEqualTo(expectedFeedbackScores); - - if (expectedTrace.feedbackScores() != null) { - actualTrace.feedbackScores().forEach(feedbackScore -> { - assertThat(feedbackScore.createdAt()).isAfter(expectedTrace.createdAt()); - assertThat(feedbackScore.lastUpdatedAt()).isAfter(expectedTrace.createdAt()); - assertThat(feedbackScore.lastUpdatedBy()).isEqualTo(USER); - assertThat(feedbackScore.lastUpdatedBy()).isEqualTo(USER); - }); - } + private void assertIgnoredFields(List actualTraces, List expectedTraces) { + for (int i = 0; i < actualTraces.size(); i++) { + var actualTrace = actualTraces.get(i); + var expectedTrace = expectedTraces.get(i); + var expectedFeedbackScores = expectedTrace.feedbackScores() == null + ? 
null + : expectedTrace.feedbackScores().reversed(); + assertThat(actualTrace.projectId()).isNotNull(); + assertThat(actualTrace.projectName()).isNull(); + assertThat(actualTrace.createdAt()).isAfter(expectedTrace.createdAt()); + assertThat(actualTrace.lastUpdatedAt()).isAfter(expectedTrace.lastUpdatedAt()); + assertThat(actualTrace.lastUpdatedBy()).isEqualTo(USER); + assertThat(actualTrace.lastUpdatedBy()).isEqualTo(USER); + assertThat(actualTrace.feedbackScores()) + .usingRecursiveComparison( + RecursiveComparisonConfiguration.builder() + .withComparatorForType(BigDecimal::compareTo, BigDecimal.class) + .withIgnoredFields(IGNORED_FIELDS) + .build()) + .isEqualTo(expectedFeedbackScores); + + if (expectedTrace.feedbackScores() != null) { + actualTrace.feedbackScores().forEach(feedbackScore -> { + assertThat(feedbackScore.createdAt()).isAfter(expectedTrace.createdAt()); + assertThat(feedbackScore.lastUpdatedAt()).isAfter(expectedTrace.createdAt()); + assertThat(feedbackScore.lastUpdatedBy()).isEqualTo(USER); + assertThat(feedbackScore.lastUpdatedBy()).isEqualTo(USER); + }); } } + } - private List updateFeedbackScore(List feedbackScores, int index, double val) { - feedbackScores.set(index, feedbackScores.get(index).toBuilder() - .value(BigDecimal.valueOf(val)) - .build()); - return feedbackScores; - } + private List updateFeedbackScore(List feedbackScores, int index, double val) { + feedbackScores.set(index, feedbackScores.get(index).toBuilder() + .value(BigDecimal.valueOf(val)) + .build()); + return feedbackScores; + } - private List updateFeedbackScore( - List destination, List source, int index) { - destination.set(index, source.get(index).toBuilder().build()); - return destination; - } + private List updateFeedbackScore( + List destination, List source, int index) { + destination.set(index, source.get(index).toBuilder().build()); + return destination; } @Nested @@ -2808,8 +2826,7 @@ private UUID create(Trace trace, String apiKey, String workspaceName) { 
assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - return UUID.fromString(actualResponse.getHeaderString("Location") - .substring(actualResponse.getHeaderString("Location").lastIndexOf('/') + 1)); + return TestUtils.getIdFromLocation(actualResponse.getLocation()); } } @@ -2893,7 +2910,7 @@ void create() { assertThat(actualResponse.getHeaderString("Location")).matches(Pattern.compile(URL_PATTERN)); } - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); getAndAssert(trace, id, projectId, now, API_KEY, TEST_WORKSPACE); } @@ -2921,8 +2938,8 @@ void create__whenCreatingTracesWithDifferentWorkspacesNames__thenReturnCreatedTr var createdTrace2 = Instant.now(); UUID id2 = TracesResourceTest.this.create(trace2, API_KEY, TEST_WORKSPACE); - UUID projectId1 = getProjectId(client, DEFAULT_PROJECT, TEST_WORKSPACE, API_KEY); - UUID projectId2 = getProjectId(client, projectName, TEST_WORKSPACE, API_KEY); + UUID projectId1 = getProjectId(DEFAULT_PROJECT, TEST_WORKSPACE, API_KEY); + UUID projectId2 = getProjectId(projectName, TEST_WORKSPACE, API_KEY); getAndAssert(trace1, id1, projectId1, createdTrace1, API_KEY, TEST_WORKSPACE); getAndAssert(trace2, id2, projectId2, createdTrace2, API_KEY, TEST_WORKSPACE); @@ -2953,10 +2970,9 @@ void create__whenIdComesFromClient__thenAcceptAndUseId() { assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); - String actualId = actualResponse.getLocation().toString() - .substring(actualResponse.getLocation().toString().lastIndexOf('/') + 1); + UUID actualId = TestUtils.getIdFromLocation(actualResponse.getLocation()); - assertThat(UUID.fromString(actualId)).isEqualTo(traceId); + assertThat(actualId).isEqualTo(traceId); } } @@ -3027,7 +3043,7 @@ void create__whenProjectNameIsNull__thenAcceptAndUseDefaultProject() { var actualResponse = getById(id, TEST_WORKSPACE, API_KEY); 
assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); - UUID projectId = getProjectId(client, DEFAULT_PROJECT, TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(DEFAULT_PROJECT, TEST_WORKSPACE, API_KEY); var actualEntity = actualResponse.readEntity(Trace.class); assertThat(actualEntity.projectId()).isEqualTo(projectId); @@ -3035,6 +3051,128 @@ void create__whenProjectNameIsNull__thenAcceptAndUseDefaultProject() { } + @Nested + @DisplayName("Batch:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class BatchInsert { + + @Test + void batch__whenCreateTraces__thenReturnNoContent() { + + String projectName = UUID.randomUUID().toString(); + + UUID projectId = createProject(projectName, TEST_WORKSPACE, API_KEY); + + var expectedTraces = IntStream.range(0, 1000) + .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() + .projectName(projectName) + .projectId(projectId) + .endTime(null) + .feedbackScores(null) + .build()) + .toList(); + + batchCreateAndAssert(expectedTraces, API_KEY, TEST_WORKSPACE); + + getAndAssertPage(TEST_WORKSPACE, projectName, List.of(), List.of(), expectedTraces.reversed(), List.of(), + API_KEY); + } + + @Test + void batch__whenSendingMultipleTracesWithSameId__thenReturn422() { + var trace = factory.manufacturePojo(Trace.class).toBuilder() + .projectId(null) + .feedbackScores(null) + .build(); + + var expectedTrace = trace.toBuilder() + .tags(Set.of()) + .endTime(Instant.now()) + .output(JsonUtils.getJsonNodeFromString("{ \"output\": \"data\"}")) + .build(); + + List traces = List.of(trace, expectedTrace); + + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path("batch") + .request() + .header(HttpHeaders.AUTHORIZATION, API_KEY) + .header(WORKSPACE_HEADER, TEST_WORKSPACE) + .post(Entity.json(new TraceBatch(traces)))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422); + assertThat(actualResponse.hasEntity()).isTrue(); + + var errorMessage = 
actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class); + assertThat(errorMessage.getMessage()).isEqualTo("Duplicate trace id '%s'".formatted(trace.id())); + } + } + + @ParameterizedTest + @MethodSource + void batch__whenBatchIsInvalid__thenReturn422(List traces, String errorMessage) { + + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path("batch") + .request() + .header(HttpHeaders.AUTHORIZATION, API_KEY) + .header(WORKSPACE_HEADER, TEST_WORKSPACE) + .post(Entity.json(new TraceBatch(traces)))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(422); + assertThat(actualResponse.hasEntity()).isTrue(); + + var responseBody = actualResponse.readEntity(ErrorMessage.class); + assertThat(responseBody.errors()).contains(errorMessage); + } + } + + Stream batch__whenBatchIsInvalid__thenReturn422() { + return Stream.of( + Arguments.of(List.of(), "traces size must be between 1 and 1000"), + Arguments.of(IntStream.range(0, 1001) + .mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder() + .projectId(null) + .feedbackScores(null) + .build()) + .toList(), "traces size must be between 1 and 1000")); + } + + @Test + void batch__whenSendingMultipleTracesWithNoId__thenReturnNoContent() { + var newTrace = factory.manufacturePojo(Trace.class).toBuilder() + .projectId(null) + .id(null) + .feedbackScores(null) + .build(); + + var expectedTrace = newTrace.toBuilder() + .tags(Set.of()) + .endTime(Instant.now()) + .output(JsonUtils.getJsonNodeFromString("{ \"output\": \"data\"}")) + .build(); + + List expectedTraces = List.of(newTrace, expectedTrace); + + batchCreateAndAssert(expectedTraces, API_KEY, TEST_WORKSPACE); + } + + private void batchCreateAndAssert(List traces, String apiKey, String workspaceName) { + + try (var actualResponse = client.target(URL_TEMPLATE.formatted(baseURI)) + .path("batch") + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + 
.post(Entity.json(new TraceBatch(traces)))) { + + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(204); + assertThat(actualResponse.hasEntity()).isFalse(); + } + } + + } + @Nested @DisplayName("Delete:") @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -3157,7 +3295,7 @@ void when__traceDoesNotExist__thenReturnCreateIt() { assertThat(actualEntity.metadata()).isEqualTo(traceUpdate.metadata()); assertThat(actualEntity.tags()).isEqualTo(traceUpdate.tags()); - UUID projectId = getProjectId(client, traceUpdate.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(traceUpdate.projectName(), TEST_WORKSPACE, API_KEY); assertThat(actualEntity.name()).isEmpty(); assertThat(actualEntity.startTime()).isEqualTo(Instant.EPOCH); @@ -3375,7 +3513,7 @@ void update__whenTagsIsEmpty__thenAcceptUpdate() { runPatchAndAssertStatus(id, traceUpdate, API_KEY, TEST_WORKSPACE); - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); Trace actualTrace = getAndAssert(trace, id, projectId, trace.createdAt().minusMillis(1), API_KEY, TEST_WORKSPACE); @@ -3396,7 +3534,7 @@ void update__whenMetadataIsEmpty__thenAcceptUpdate() { runPatchAndAssertStatus(id, traceUpdate, API_KEY, TEST_WORKSPACE); - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); Trace actualTrace = getAndAssert(trace.toBuilder().metadata(metadata).build(), id, projectId, trace.createdAt().minusMillis(1), API_KEY, TEST_WORKSPACE); @@ -3417,7 +3555,7 @@ void update__whenInputIsEmpty__thenAcceptUpdate() { runPatchAndAssertStatus(id, traceUpdate, API_KEY, TEST_WORKSPACE); - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); Trace actualTrace = 
getAndAssert(trace.toBuilder().input(input).build(), id, projectId, trace.createdAt().minusMillis(1), API_KEY, TEST_WORKSPACE); @@ -3438,7 +3576,7 @@ void update__whenOutputIsEmpty__thenAcceptUpdate() { runPatchAndAssertStatus(id, traceUpdate, API_KEY, TEST_WORKSPACE); - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); Trace actualTrace = getAndAssert(trace.toBuilder().output(output).build(), id, projectId, trace.createdAt().minusMillis(1), API_KEY, TEST_WORKSPACE); @@ -3450,7 +3588,7 @@ void update__whenOutputIsEmpty__thenAcceptUpdate() { @DisplayName("when updating using projectId, then accept update") void update__whenUpdatingUsingProjectId__thenAcceptUpdate() { - var projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + var projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); var traceUpdate = factory.manufacturePojo(TraceUpdate.class).toBuilder() .projectId(projectId) @@ -3578,7 +3716,7 @@ void feedback__whenFeedbackWithoutCategoryNameOrReason__thenReturnNoContent() { create(id, score, TEST_WORKSPACE, API_KEY); - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); var actualEntity = getAndAssert(trace, id, projectId, now, API_KEY, TEST_WORKSPACE); @@ -3615,7 +3753,7 @@ void feedback__whenFeedbackWithCategoryNameOrReason__thenReturnNoContent() { create(id, score, TEST_WORKSPACE, API_KEY); - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); Trace actualEntity = getAndAssert(trace, id, projectId, now, API_KEY, TEST_WORKSPACE); @@ -3653,7 +3791,7 @@ void feedback__whenOverridingFeedbackValue__thenReturnNoContent() { FeedbackScore newScore = 
score.toBuilder().value(BigDecimal.valueOf(2)).build(); create(id, newScore, TEST_WORKSPACE, API_KEY); - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); var actualEntity = getAndAssert(trace, id, projectId, now, API_KEY, TEST_WORKSPACE); assertThat(actualEntity.feedbackScores()).hasSize(1); @@ -3846,8 +3984,8 @@ void feedback() { assertThat(actualResponse.hasEntity()).isFalse(); } - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); - UUID projectId2 = getProjectId(client, trace2.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId2 = getProjectId(trace2.projectName(), TEST_WORKSPACE, API_KEY); var actualTrace1 = getAndAssert(trace, id, projectId, now, API_KEY, TEST_WORKSPACE); var actualTrace2 = getAndAssert(trace2, id2, projectId2, now, API_KEY, TEST_WORKSPACE); @@ -3917,8 +4055,8 @@ void feedback__whenWorkspaceIsSpecified__thenReturnNoContent() { assertThat(actualResponse.hasEntity()).isFalse(); } - UUID projectId = getProjectId(client, DEFAULT_PROJECT, workspaceName, apiKey); - UUID projectId2 = getProjectId(client, projectName, workspaceName, apiKey); + UUID projectId = getProjectId(DEFAULT_PROJECT, workspaceName, apiKey); + UUID projectId2 = getProjectId(projectName, workspaceName, apiKey); var actualTrace1 = getAndAssert(expectedTrace1, id, projectId, now, apiKey, workspaceName); var actualTrace2 = getAndAssert(expectedTrace2, id2, projectId2, now, apiKey, workspaceName); @@ -3987,7 +4125,7 @@ void feedback__whenFeedbackWithoutCategoryNameOrReason__thenReturnNoContent() { assertThat(actualResponse.hasEntity()).isFalse(); } - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); var actualEntity = getAndAssert(trace, id, 
projectId, now, API_KEY, TEST_WORKSPACE); @@ -4038,7 +4176,7 @@ void feedback__whenFeedbackWithCategoryNameOrReason__thenReturnNoContent() { } var actualEntity = getAndAssert(expectedTrace, id, - getProjectId(client, expectedTrace.projectName(), TEST_WORKSPACE, API_KEY), now, API_KEY, + getProjectId(expectedTrace.projectName(), TEST_WORKSPACE, API_KEY), now, API_KEY, TEST_WORKSPACE); assertThat(actualEntity.feedbackScores()).hasSize(1); @@ -4100,7 +4238,7 @@ void feedback__whenOverridingFeedbackValue__thenReturnNoContent() { assertThat(actualResponse.hasEntity()).isFalse(); } - UUID projectId = getProjectId(client, trace.projectName(), TEST_WORKSPACE, API_KEY); + UUID projectId = getProjectId(trace.projectName(), TEST_WORKSPACE, API_KEY); var actualEntity = getAndAssert(trace, id, projectId, now, API_KEY, TEST_WORKSPACE); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java new file mode 100644 index 0000000000..faf96f9979 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/DummyLockService.java @@ -0,0 +1,18 @@ +package com.comet.opik.domain; + +import com.comet.opik.infrastructure.redis.LockService; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class DummyLockService implements LockService { + + @Override + public Mono executeWithLock(LockService.Lock lock, Mono action) { + return action; + } + + @Override + public Flux executeWithLock(LockService.Lock lock, Flux action) { + return action; + } +} \ No newline at end of file diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java index 7570c2406c..c5e404b98a 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/SpanServiceTest.java @@ -10,7 
+10,6 @@ import jakarta.ws.rs.NotFoundException; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import uk.co.jemos.podam.api.PodamFactory; @@ -26,18 +25,7 @@ @Disabled class SpanServiceTest { - private static final LockService DUMMY_LOCK_SERVICE = new LockService() { - - @Override - public Mono executeWithLock(Lock lock, Mono action) { - return action; - } - - @Override - public Flux executeWithLock(Lock lock, Flux action) { - return action; - } - }; + private static final LockService DUMMY_LOCK_SERVICE = new DummyLockService(); private final PodamFactory podamFactory = PodamFactoryUtils.newPodamFactory(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java index 78d0a40666..2188190fdc 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/TraceServiceImplTest.java @@ -19,7 +19,6 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import uk.co.jemos.podam.api.PodamFactory; import uk.co.jemos.podam.api.PodamFactoryImpl; @@ -40,18 +39,7 @@ @ExtendWith(MockitoExtension.class) class TraceServiceImplTest { - public static final LockService DUMMY_LOCK_SERVICE = new LockService() { - - @Override - public Mono executeWithLock(Lock lock, Mono action) { - return action; - } - - @Override - public Flux executeWithLock(Lock lock, Flux action) { - return action; - } - }; + public static final LockService DUMMY_LOCK_SERVICE = new DummyLockService(); private TraceServiceImpl traceService; diff --git a/build_and_run.sh b/build_and_run.sh index 89e740ee91..4105557a77 100755 --- a/build_and_run.sh +++ b/build_and_run.sh 
@@ -15,6 +15,7 @@ BUILD=true FE_BUILD=true HELM_UPDATE=true LOCAL_FE=false +LOCAL_FE_PORT=${LOCAL_FE_PORT:-5174} CLOUD_VERSION=false function show_help() { @@ -93,15 +94,16 @@ echo "### Install Opik using latest versions" cd deployment/helm_chart/opik VERSION=latest -LOCAL_FE_FLAGS="" if [[ "${LOCAL_FE}" == "true" ]]; then - if [[ "$OSTYPE" == "darwin"* ]]; then - IP_ADDRESS=$(ifconfig | grep 'inet ' | grep -v '127.0.0.1' | awk '{print $2}' | head -n 1) - else - IP_ADDRESS=$(hostname -I | awk '{print $1}') + LOCAL_FE_FLAGS="--set localFE=true" + if [ -z "${LOCAL_FE_HOST}" ] ; then + if [[ "$OSTYPE" == "darwin"* ]]; then + LOCAL_FE_HOST=$(ifconfig | grep 'inet ' | grep -vF '127.0.0.1' | awk '{print $2}' | head -n 1) + else + LOCAL_FE_HOST=$(hostname -I | awk '{print $1}') + fi fi - - LOCAL_FE_FLAGS="--set localFE=true --set localFEAddress=$IP_ADDRESS"; + LOCAL_FE_FLAGS="${LOCAL_FE_FLAGS} --set localFEAddress=${LOCAL_FE_HOST}:${LOCAL_FE_PORT}"; fi CLOUD_VERSION_FLAGS="" @@ -156,24 +158,27 @@ while true; do sleep $INTERVAL done +echo "### Waiting for pods" +kubectl wait --for=condition=ready pod --all + echo "### Port-forward Opik Frontend to local host" # remove the previous port-forward -ps -ef | grep "svc/${OPIK_FRONTEND} ${OPIK_FRONTEND_PORT}" | grep -v grep | awk '{print $2}' | xargs kill || true -kubectl port-forward svc/${OPIK_FRONTEND} ${OPIK_FRONTEND_PORT} & +ps -ef | grep "svc/${OPIK_FRONTEND} ${OPIK_FRONTEND_PORT}" | grep -v grep | awk '{print $2}' | xargs kill 2>/dev/null|| true +kubectl port-forward svc/${OPIK_FRONTEND} ${OPIK_FRONTEND_PORT} > /dev/null 2>&1 & echo "### Port-forward Open API to local host" # remove the previous port-forward -ps -ef | grep "svc/${OPIK_BACKEND} ${OPIK_OPENAPI_PORT}" | grep -v grep | awk '{print $2}' | xargs kill || true -kubectl port-forward svc/${OPIK_BACKEND} ${OPIK_OPENAPI_PORT} & +ps -ef | grep "svc/${OPIK_BACKEND} ${OPIK_OPENAPI_PORT}" | grep -v grep | awk '{print $2}' | xargs kill 2>/dev/null|| true +kubectl 
port-forward svc/${OPIK_BACKEND} ${OPIK_OPENAPI_PORT} > /dev/null 2>&1 & echo "### Port-forward Clickhouse to local host" # remove the previous port-forward -ps -ef | grep "svc/${OPIK_CLICKHOUSE} ${OPIK_CLICKHOUSE_PORT}" | grep -v grep | awk '{print $2}' | xargs kill || true -kubectl port-forward svc/${OPIK_CLICKHOUSE} ${OPIK_CLICKHOUSE_PORT} & +ps -ef | grep "svc/${OPIK_CLICKHOUSE} ${OPIK_CLICKHOUSE_PORT}" | grep -v grep | awk '{print $2}' | xargs kill 2>/dev/null|| true +kubectl port-forward svc/${OPIK_CLICKHOUSE} ${OPIK_CLICKHOUSE_PORT} > /dev/null 2>&1 & echo "### Port-forward MySQL to local host" # remove the previous port-forward -ps -ef | grep "svc/${OPIK_MYSQL} ${OPIK_MYSQL_PORT}" | grep -v grep | awk '{print $2}' | xargs kill || true -kubectl port-forward svc/${OPIK_MYSQL} ${OPIK_MYSQL_PORT} & +ps -ef | grep "svc/${OPIK_MYSQL} ${OPIK_MYSQL_PORT}" | grep -v grep | awk '{print $2}' | xargs kill 2>/dev/null|| true +kubectl port-forward svc/${OPIK_MYSQL} ${OPIK_MYSQL_PORT} > /dev/null 2>&1 & echo "Now you can open your browser and connect http://localhost:${OPIK_FRONTEND_PORT}"