Skip to content

Commit

Permalink
[OPIK-210] Fix NPE in batch endpoint when project name is null (#351)
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagohora authored Oct 7, 2024
1 parent 9d0190a commit b4b57d1
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,12 @@ public Mono<Long> create(@NonNull SpanBatch batch) {
List<String> projectNames = batch.spans()
.stream()
.map(Span::projectName)
.map(WorkspaceUtils::getProjectName)
.distinct()
.toList();

Mono<List<Span>> resolveProjects = Flux.fromIterable(projectNames)
.flatMap(this::resolveProject)
.flatMap(this::getOrCreateProject)
.collectList()
.map(projects -> bindSpanToProjectAndId(batch, projects))
.subscribeOn(Schedulers.boundedElastic());
Expand All @@ -293,7 +294,4 @@ private List<Span> bindSpanToProjectAndId(SpanBatch batch, List<Project> project
.toList();
}

private Mono<Project> resolveProject(String projectName) {
return getOrCreateProject(WorkspaceUtils.getProjectName(projectName));
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.comet.opik.domain;

import com.comet.opik.api.Trace;
import com.comet.opik.api.TraceCountResponse;
import com.comet.opik.api.TraceSearchCriteria;
import com.comet.opik.api.TraceUpdate;
import com.comet.opik.domain.filter.FilterQueryBuilder;
Expand Down Expand Up @@ -36,6 +35,7 @@
import java.util.stream.Collectors;

import static com.comet.opik.api.Trace.TracePage;
import static com.comet.opik.api.TraceCountResponse.WorkspaceTraceCount;
import static com.comet.opik.domain.AsyncContextUtils.bindUserNameAndWorkspaceContext;
import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToFlux;
import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToMono;
Expand Down Expand Up @@ -67,7 +67,7 @@ interface TraceDAO {

Mono<Long> batchInsert(List<Trace> traces, Connection connection);

Flux<TraceCountResponse.WorkspaceTraceCount> countTracesPerWorkspace(Connection connection);
Flux<WorkspaceTraceCount> countTracesPerWorkspace(Connection connection);
}

@Slf4j
Expand Down Expand Up @@ -877,12 +877,12 @@ private String getOrDefault(JsonNode value) {
}

@com.newrelic.api.agent.Trace(dispatcher = true)
public Flux<TraceCountResponse.WorkspaceTraceCount> countTracesPerWorkspace(Connection connection) {
public Flux<WorkspaceTraceCount> countTracesPerWorkspace(Connection connection) {

var statement = connection.createStatement(TRACE_COUNT_BY_WORKSPACE_ID);

return Mono.from(statement.execute())
.flatMapMany(result -> result.map((row, rowMetadata) -> TraceCountResponse.WorkspaceTraceCount.builder()
.flatMapMany(result -> result.map((row, rowMetadata) -> WorkspaceTraceCount.builder()
.workspace(row.get("workspace_id", String.class))
.traceCount(row.get("trace_count", Integer.class)).build()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,12 @@ public Mono<Long> create(TraceBatch batch) {
List<String> projectNames = batch.traces()
.stream()
.map(Trace::projectName)
.map(WorkspaceUtils::getProjectName)
.distinct()
.toList();

Mono<List<Trace>> resolveProjects = Flux.fromIterable(projectNames)
.flatMap(this::resolveProject)
.flatMap(this::getOrCreateProject)
.collectList()
.map(projects -> bindTraceToProjectAndId(batch, projects))
.subscribeOn(Schedulers.boundedElastic());
Expand All @@ -132,10 +133,6 @@ private List<Trace> bindTraceToProjectAndId(TraceBatch batch, List<Project> proj
.toList();
}

private Mono<Project> resolveProject(String projectName) {
return getOrCreateProject(WorkspaceUtils.getProjectName(projectName));
}

private Mono<UUID> insertTrace(Trace newTrace, Project project, UUID id) {
//TODO: refactor to implement proper conflict resolution
return template.nonTransaction(connection -> dao.findById(id, connection))
Expand Down Expand Up @@ -327,6 +324,7 @@ public Mono<Boolean> validateTraceWorkspace(@NonNull String workspaceId, @NonNul
}

@Override
@com.newrelic.api.agent.Trace(dispatcher = true)
public Mono<TraceCountResponse> countTracesPerWorkspace() {
return template.stream(dao::countTracesPerWorkspace)
.collectList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ public class ClickHouseContainerUtils {
public static final String DATABASE_NAME_VARIABLE = "ANALYTICS_DB_DATABASE_NAME";

public static ClickHouseContainer newClickHouseContainer() {
return newClickHouseContainer(true);
}

public static ClickHouseContainer newClickHouseContainer(boolean reusable) {
// TODO: Use non-deprecated ClickHouseContainer: https://github.com/comet-ml/opik/issues/58
return new ClickHouseContainer(
DockerImageName.parse("clickhouse/clickhouse-server:24.3.8.13-alpine"))
.withReuse(true);
.withReuse(reusable);
}

public static DatabaseAnalyticsFactory newDatabaseAnalyticsFactory(ClickHouseContainer clickHouseContainer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class UsageResourceTest {

private static final MySQLContainer<?> MYSQL_CONTAINER = MySQLContainerUtils.newMySQLContainer();

private static final ClickHouseContainer CLICK_HOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer();
private static final ClickHouseContainer CLICK_HOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer(false);

@RegisterExtension
private static final TestDropwizardAppExtension app;
Expand Down Expand Up @@ -134,21 +134,21 @@ void tracesCountForWorkspace() {
// Setup second workspace with traces, but leave created_at date set to today, so traces do not end up in the pool
var workspaceNameForToday = UUID.randomUUID().toString();
var workspaceIdForToday = UUID.randomUUID().toString();
setupTracesForWorkspace(workspaceNameForToday, workspaceIdForToday, okApikey);
var apikey = UUID.randomUUID().toString();

setupTracesForWorkspace(workspaceNameForToday, workspaceIdForToday, apikey);

try (var actualResponse = client.target(USAGE_RESOURCE_URL_TEMPLATE.formatted(baseURI))
.path("/workspace-trace-counts")
.request()
.header(HttpHeaders.AUTHORIZATION, okApikey)
.header(WORKSPACE_HEADER, workspaceName)
.get()) {

assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200);
assertThat(actualResponse.hasEntity()).isTrue();

var response = actualResponse.readEntity(TraceCountResponse.class);
assertThat(response.workspacesTracesCount().size()).isEqualTo(1);
assertThat(response.workspacesTracesCount().get(0))
assertThat(response.workspacesTracesCount()).hasSize(1);
assertThat(response.workspacesTracesCount().getFirst())
.isEqualTo(new TraceCountResponse.WorkspaceTraceCount(workspaceId, tracesCount));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3231,6 +3231,30 @@ void batch__whenCreateSpans__thenReturnNoContent() {
API_KEY);
}

@Test
void batch__whenSpansProjectNameIsNull__thenUserDefaultProjectAndReturnNoContent() {

String apiKey = UUID.randomUUID().toString();
String workspaceName = UUID.randomUUID().toString();
String workspaceId = UUID.randomUUID().toString();

mockTargetWorkspace(apiKey, workspaceName, workspaceId);

var expectedSpans = PodamFactoryUtils.manufacturePojoList(podamFactory, Span.class).stream()
.map(trace -> trace.toBuilder()
.projectName(null)
.endTime(null)
.usage(null)
.feedbackScores(null)
.build())
.toList();

batchCreateAndAssert(expectedSpans, apiKey, workspaceName);

getAndAssertPage(workspaceName, DEFAULT_PROJECT, List.of(), List.of(), expectedSpans.reversed(), List.of(),
apiKey);
}

@Test
void batch__whenSendingMultipleSpansWithSameId__thenReturn422() {
var expectedSpans = List.of(podamFactory.manufacturePojo(Span.class).toBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3434,12 +3434,11 @@ void batch__whenCreateTraces__thenReturnNoContent() {

var projectName = UUID.randomUUID().toString();

var projectId = createProject(projectName, TEST_WORKSPACE, API_KEY);
createProject(projectName, TEST_WORKSPACE, API_KEY);

var expectedTraces = IntStream.range(0, 1000)
.mapToObj(i -> factory.manufacturePojo(Trace.class).toBuilder()
.projectName(projectName)
.projectId(projectId)
.endTime(null)
.usage(null)
.feedbackScores(null)
Expand All @@ -3452,6 +3451,30 @@ void batch__whenCreateTraces__thenReturnNoContent() {
API_KEY);
}

@Test
void batch__whenTraceProjectNameIsNull__thenUserDefaultProjectAndReturnNoContent() {

String apiKey = UUID.randomUUID().toString();
String workspaceName = UUID.randomUUID().toString();
String workspaceId = UUID.randomUUID().toString();

mockTargetWorkspace(apiKey, workspaceName, workspaceId);

var expectedTraces = PodamFactoryUtils.manufacturePojoList(factory, Trace.class).stream()
.map(trace -> trace.toBuilder()
.projectName(null)
.endTime(null)
.usage(null)
.feedbackScores(null)
.build())
.toList();

batchCreateTracesAndAssert(expectedTraces, apiKey, workspaceName);

getAndAssertPage(workspaceName, DEFAULT_PROJECT, List.of(), List.of(), expectedTraces.reversed(), List.of(),
apiKey);
}

@Test
void batch__whenSendingMultipleTracesWithSameId__thenReturn422() {
var trace = factory.manufacturePojo(Trace.class).toBuilder()
Expand Down

0 comments on commit b4b57d1

Please sign in to comment.