Skip to content

Commit

Permalink
[OPIK-261] Feedback score batch fix for closed connection (#395)
Browse files Browse the repository at this point in the history
* [OPIK-261] Feedback score batch fix for closed connection

* Fix validation

* Address PR feedback

* Remove locks
  • Loading branch information
thiagohora authored Oct 16, 2024
1 parent 9d8ed61 commit 2aa5683
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 373 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,7 @@ public interface FeedbackScoreMapper {

@Mapping(target = "id", source = "entityId")
FeedbackScoreBatchItem toFeedbackScoreBatchItem(UUID entityId, String projectName, FeedbackScore feedbackScore);

@Mapping(target = "id", source = "entityId")
FeedbackScoreBatchItem toFeedbackScore(UUID entityId, UUID projectId, FeedbackScore score);
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
package com.comet.opik.domain;

import com.clickhouse.client.ClickHouseException;
import com.comet.opik.api.FeedbackScore;
import com.comet.opik.api.FeedbackScoreBatchItem;
import com.comet.opik.api.Project;
import com.comet.opik.api.error.ErrorMessage;
import com.comet.opik.api.error.IdentifierMismatchException;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.infrastructure.db.TransactionTemplateAsync;
import com.comet.opik.infrastructure.lock.LockService;
import com.comet.opik.utils.WorkspaceUtils;
import com.google.inject.ImplementedBy;
import com.google.inject.Singleton;
Expand All @@ -22,10 +18,12 @@
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate;

import java.sql.SQLIntegrityConstraintViolationException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;

Expand Down Expand Up @@ -53,43 +51,48 @@ public interface FeedbackScoreService {
@RequiredArgsConstructor(onConstructor_ = @Inject)
class FeedbackScoreServiceImpl implements FeedbackScoreService {

private static final String SPAN_SCORE_KEY = "span-score-%s";
private static final String TRACE_SCORE_KEY = "trace-score-%s";

private final @NonNull FeedbackScoreDAO dao;
private final @NonNull ru.vyarus.guicey.jdbi3.tx.TransactionTemplate syncTemplate;
private final @NonNull TransactionTemplateAsync asyncTemplate;
private final @NonNull TransactionTemplate syncTemplate;
private final @NonNull SpanDAO spanDAO;
private final @NonNull TraceDAO traceDAO;
private final @NonNull IdGenerator idGenerator;
private final @NonNull LockService lockService;

record ProjectDto(Project project, List<FeedbackScoreBatchItem> scores) {
}

@Override
public Mono<Void> scoreTrace(@NonNull UUID traceId, @NonNull FeedbackScore score) {
return lockService.executeWithLock(
new LockService.Lock(traceId, TRACE_SCORE_KEY.formatted(score.name())),
Mono.defer(() -> asyncTemplate
.nonTransaction(connection -> dao.scoreEntity(EntityType.TRACE, traceId, score, connection))))
.flatMap(this::extractResult)
.switchIfEmpty(Mono.defer(() -> Mono.error(failWithTraceNotFound(traceId))))
.then();
return traceDAO.getProjectIdFromTraces(Set.of(traceId))
.flatMap(traceProjectIdMap -> {

if (traceProjectIdMap.get(traceId) == null) {
return Mono.error(failWithTraceNotFound(traceId));
}

return dao.scoreEntity(EntityType.TRACE, traceId, score, traceProjectIdMap)
.flatMap(this::extractResult)
.then();
});
}

@Override
public Mono<Void> scoreSpan(@NonNull UUID spanId, @NonNull FeedbackScore score) {
return lockService.executeWithLock(
new LockService.Lock(spanId, SPAN_SCORE_KEY.formatted(score.name())),
Mono.defer(() -> asyncTemplate
.nonTransaction(connection -> dao.scoreEntity(EntityType.SPAN, spanId, score, connection))))
.flatMap(this::extractResult)
.switchIfEmpty(Mono.defer(() -> Mono.error(failWithSpanNotFound(spanId))))
.then();

return spanDAO.getProjectIdFromSpans(Set.of(spanId))
.flatMap(spanProjectIdMap -> {

if (spanProjectIdMap.get(spanId) == null) {
return Mono.error(failWithSpanNotFound(spanId));
}

return dao.scoreEntity(EntityType.SPAN, spanId, score, spanProjectIdMap)
.flatMap(this::extractResult)
.then();
});
}

@Override
public Mono<Void> scoreBatchOfSpans(@NonNull List<FeedbackScoreBatchItem> scores) {

return processScoreBatch(EntityType.SPAN, scores);
}

Expand Down Expand Up @@ -120,7 +123,6 @@ private Mono<Void> processScoreBatch(EntityType entityType, List<FeedbackScoreBa
.map(this::groupByName)
.map(projectMap -> mergeProjectsAndScores(projectMap, scoresPerProject))
.flatMap(projects -> processScoreBatch(entityType, projects, scores.size())) // score all scores
.onErrorResume(e -> tryHandlingException(entityType, e))
.then();
}

Expand All @@ -144,36 +146,9 @@ private Mono<Void> checkIfNeededToCreateProjectsWithContext(String workspaceId,
.then();
}

private Mono<Long> tryHandlingException(EntityType entityType, Throwable e) {
return switch (e) {
case ClickHouseException clickHouseException -> {
//TODO: Find a better way to handle this.
// This is a workaround to handle the case when project_id from score and project_name from project does not match.
if (clickHouseException.getMessage().contains("TOO_LARGE_STRING_SIZE") &&
clickHouseException.getMessage().contains("_CAST(project_id, FixedString(36))")) {
yield failWithConflict("project_name from score and project_id from %s does not match"
.formatted(entityType.getType()));
}
yield Mono.error(e);
}
default -> Mono.error(e);
};
}

private Mono<Long> failWithConflict(String message) {
return Mono.error(new IdentifierMismatchException(new ErrorMessage(List.of(message))));
}

private Mono<Long> processScoreBatch(EntityType entityType, List<ProjectDto> projects, int actualBatchSize) {
return Flux.fromIterable(projects)
.flatMap(projectDto -> {
var lock = new LockService.Lock(projectDto.project().id(), "%s-scores-batch".formatted(entityType));

Mono<Long> batchProcess = Mono.defer(() -> asyncTemplate.nonTransaction(
connection -> dao.scoreBatchOf(entityType, projectDto.scores(), connection)));

return lockService.executeWithLock(lock, batchProcess);
})
.flatMap(projectDto -> dao.scoreBatchOf(entityType, projectDto.scores()))
.reduce(0L, Long::sum)
.flatMap(rowsUpdated -> rowsUpdated == actualBatchSize ? Mono.just(rowsUpdated) : Mono.empty())
.switchIfEmpty(Mono.defer(() -> failWithNotFound("Error while processing scores batch")));
Expand Down Expand Up @@ -245,18 +220,12 @@ private void checkIfNeededToCreateProjects(Map<String, List<FeedbackScoreBatchIt

@Override
public Mono<Void> deleteSpanScore(UUID id, String name) {
return lockService.executeWithLock(
new LockService.Lock(id, SPAN_SCORE_KEY.formatted(name)),
Mono.defer(() -> asyncTemplate
.nonTransaction(connection -> dao.deleteScoreFrom(EntityType.SPAN, id, name, connection))));
return dao.deleteScoreFrom(EntityType.SPAN, id, name);
}

@Override
public Mono<Void> deleteTraceScore(UUID id, String name) {
return lockService.executeWithLock(
new LockService.Lock(id, TRACE_SCORE_KEY.formatted(name)),
Mono.defer(() -> asyncTemplate
.nonTransaction(connection -> dao.deleteScoreFrom(EntityType.TRACE, id, name, connection))));
return dao.deleteScoreFrom(EntityType.TRACE, id, name);
}

private Mono<Long> failWithNotFound(String errorMessage) {
Expand Down
43 changes: 36 additions & 7 deletions apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,17 @@ AND id in (
;
""";

// ClickHouse query resolving each span id to its project_id within the
// caller's workspace. ORDER BY ... last_updated_at DESC + LIMIT 1 BY id keeps
// only one row per span id — NOTE(review): presumably the latest version of an
// append-updated row; confirm against the spans table engine.
public static final String SELECT_PROJECT_ID_FROM_SPANS = """
SELECT
id,
project_id
FROM spans
WHERE id IN :ids
AND workspace_id = :workspace_id
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
""";

private final @NonNull ConnectionFactory connectionFactory;
private final @NonNull FeedbackScoreDAO feedbackScoreDAO;
private final @NonNull FilterQueryBuilder filterQueryBuilder;
Expand Down Expand Up @@ -727,9 +738,7 @@ public Mono<Span> getById(@NonNull UUID id) {
return Mono.from(connectionFactory.create())
.flatMapMany(connection -> getById(id, connection))
.flatMap(this::mapToDto)
.flatMap(span -> Mono.from(connectionFactory.create())
.flatMap(connection -> enhanceWithFeedbackScores(List.of(span), connection)
.map(List::getFirst)))
.flatMap(span -> enhanceWithFeedbackScores(List.of(span)).map(List::getFirst))
.singleOrEmpty();
}

Expand Down Expand Up @@ -811,17 +820,16 @@ private Mono<Span.SpanPage> find(int page, int size, SpanSearchCriteria spanSear
.flatMapMany(connection -> find(page, size, spanSearchCriteria, connection))
.flatMap(this::mapToDto)
.collectList()
.flatMap(spans -> Mono.from(connectionFactory.create())
.flatMap(connection -> enhanceWithFeedbackScores(spans, connection)))
.flatMap(this::enhanceWithFeedbackScores)
.map(spans -> new Span.SpanPage(page, spans.size(), total, spans));
}

private Mono<List<Span>> enhanceWithFeedbackScores(List<Span> spans, Connection connection) {
private Mono<List<Span>> enhanceWithFeedbackScores(List<Span> spans) {
List<UUID> spanIds = spans.stream().map(Span::id).toList();

Segment segment = startSegment("spans", "Clickhouse", "enhance_with_feedback_scores");

return feedbackScoreDAO.getScores(EntityType.SPAN, spanIds, connection)
return feedbackScoreDAO.getScores(EntityType.SPAN, spanIds)
.map(scoresMap -> spans.stream()
.map(span -> span.toBuilder().feedbackScores(scoresMap.get(span.id())).build())
.toList())
Expand Down Expand Up @@ -912,4 +920,25 @@ public Mono<List<WorkspaceAndResourceId>> getSpanWorkspace(@NonNull Set<UUID> sp
row.get("id", UUID.class))))
.collectList();
}

/**
 * Resolves the owning project id for every given span id, scoped to the
 * workspace carried in the reactive context.
 *
 * @param spanIds span ids to resolve; must not be null
 * @return map of span id -> project id; ids with no matching row are simply
 *         absent. An empty input returns an empty map without opening a
 *         connection.
 */
@WithSpan
public Mono<Map<UUID, UUID>> getProjectIdFromSpans(@NonNull Set<UUID> spanIds) {

    // Short-circuit: an empty IN (:ids) clause would be invalid anyway.
    if (spanIds.isEmpty()) {
        return Mono.just(Map.of());
    }

    return Mono.from(connectionFactory.create())
            .flatMapMany(connection -> makeFluxContextAware(bindWorkspaceIdToFlux(
                    // NOTE(review): bindWorkspaceIdToFlux presumably binds the
                    // :workspace_id placeholder from the request context.
                    connection.createStatement(SELECT_PROJECT_ID_FROM_SPANS)
                            .bind("ids", spanIds.toArray(UUID[]::new)))))
            .flatMap(queryResult -> queryResult.map((row, metadata) -> Map.entry(
                    row.get("id", UUID.class),
                    row.get("project_id", UUID.class))))
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.comet.opik.api.TraceUpdate;
import com.comet.opik.domain.filter.FilterQueryBuilder;
import com.comet.opik.domain.filter.FilterStrategy;
import com.comet.opik.infrastructure.db.TransactionTemplateAsync;
import com.comet.opik.utils.JsonUtils;
import com.comet.opik.utils.TemplateUtils;
import com.fasterxml.jackson.databind.JsonNode;
Expand Down Expand Up @@ -71,8 +72,9 @@ interface TraceDAO {

Flux<WorkspaceTraceCount> countTracesPerWorkspace(Connection connection);

Mono<Map<UUID, Instant>> getLastUpdatedTraceAt(@NonNull Set<UUID> projectIds, @NonNull String workspaceId,
@NonNull Connection connection);
Mono<Map<UUID, Instant>> getLastUpdatedTraceAt(Set<UUID> projectIds, String workspaceId, Connection connection);

Mono<Map<UUID, UUID>> getProjectIdFromTraces(Set<UUID> traceIds);
}

@Slf4j
Expand Down Expand Up @@ -509,9 +511,21 @@ LEFT JOIN (
GROUP BY t.project_id
;
""";
// ClickHouse query resolving each trace id to its project_id within the
// caller's workspace. ORDER BY ... last_updated_at DESC + LIMIT 1 BY id keeps
// only one row per trace id — NOTE(review): presumably the latest version of
// an append-updated row; confirm against the traces table engine.
private static final String SELECT_PROJECT_ID_FROM_TRACES = """
SELECT
id,
project_id
FROM traces
WHERE id IN :ids
AND workspace_id = :workspace_id
ORDER BY id DESC, last_updated_at DESC
LIMIT 1 BY id
;
""";

private final @NonNull FeedbackScoreDAO feedbackScoreDAO;
private final @NonNull FilterQueryBuilder filterQueryBuilder;
private final @NonNull TransactionTemplateAsync asyncTemplate;

@Override
@WithSpan
Expand Down Expand Up @@ -677,7 +691,7 @@ public Mono<Void> delete(Set<UUID> ids, @NonNull Connection connection) {
public Mono<Trace> findById(@NonNull UUID id, @NonNull Connection connection) {
return getById(id, connection)
.flatMap(this::mapToDto)
.flatMap(trace -> enhanceWithFeedbackLogs(List.of(trace), connection))
.flatMap(trace -> enhanceWithFeedbackLogs(List.of(trace)))
.flatMap(traces -> Mono.justOrEmpty(traces.stream().findFirst()))
.singleOrEmpty();
}
Expand Down Expand Up @@ -722,7 +736,7 @@ public Mono<TracePage> find(
.flatMap(total -> getTracesByProjectId(size, page, traceSearchCriteria, connection) //Get count then pagination
.flatMapMany(this::mapToDto)
.collectList()
.flatMap(traces -> enhanceWithFeedbackLogs(traces, connection))
.flatMap(this::enhanceWithFeedbackLogs)
.map(traces -> new TracePage(page, traces.size(), total, traces)));
}

Expand Down Expand Up @@ -750,12 +764,12 @@ public Mono<Void> partialInsert(
.then();
}

private Mono<List<Trace>> enhanceWithFeedbackLogs(List<Trace> traces, Connection connection) {
private Mono<List<Trace>> enhanceWithFeedbackLogs(List<Trace> traces) {
List<UUID> traceIds = traces.stream().map(Trace::id).toList();

Segment segment = startSegment("traces", "Clickhouse", "enhanceWithFeedbackLogs");

return feedbackScoreDAO.getScores(EntityType.TRACE, traceIds, connection)
return feedbackScoreDAO.getScores(EntityType.TRACE, traceIds)
.map(logsMap -> traces.stream()
.map(trace -> trace.toBuilder().feedbackScores(logsMap.get(trace.id())).build())
.toList())
Expand Down Expand Up @@ -926,4 +940,23 @@ public Mono<Map<UUID, Instant>> getLastUpdatedTraceAt(
}
});
}

/**
 * Resolves the owning project id for every given trace id, scoped to the
 * workspace carried in the reactive context.
 *
 * @param traceIds trace ids to resolve; must not be null
 * @return map of trace id -> project id; ids with no matching row are simply
 *         absent. An empty input returns an empty map without touching the DB.
 */
@Override
public Mono<Map<UUID, UUID>> getProjectIdFromTraces(@NonNull Set<UUID> traceIds) {

// Short-circuit: an empty IN (:ids) clause would be invalid anyway.
if (traceIds.isEmpty()) {
return Mono.just(Map.of());
}

return asyncTemplate.nonTransaction(connection -> {
var statement = connection.createStatement(SELECT_PROJECT_ID_FROM_TRACES)
.bind("ids", traceIds.toArray(UUID[]::new));

// NOTE(review): bindWorkspaceIdToFlux presumably binds the :workspace_id
// placeholder from the request context; makeFluxContextAware propagates
// that reactive context into the flux.
return makeFluxContextAware(bindWorkspaceIdToFlux(statement))
.flatMap(result -> result.map((row, rowMetadata) -> Map.entry(
row.get("id", UUID.class),
row.get("project_id", UUID.class))))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,10 @@ public Mono<Void> delete(@NonNull UUID id) {
log.info("Deleting trace by id '{}'", id);
return lockService.executeWithLock(
new LockService.Lock(id, TRACE_KEY),
Mono.defer(() -> template
.nonTransaction(
connection -> feedbackScoreDAO.deleteByEntityId(EntityType.TRACE, id, connection))
Mono.defer(() -> feedbackScoreDAO.deleteByEntityId(EntityType.TRACE, id))
.then(Mono.defer(
() -> template.nonTransaction(connection -> spanDAO.deleteByTraceId(id, connection))))
.then(Mono.defer(() -> template.nonTransaction(connection -> dao.delete(id, connection))))));
.then(Mono.defer(() -> template.nonTransaction(connection -> dao.delete(id, connection)))));
}

@Override
Expand All @@ -292,7 +290,7 @@ public Mono<Void> delete(Set<UUID> ids) {
Preconditions.checkArgument(CollectionUtils.isNotEmpty(ids), "Argument 'ids' must not be empty");
log.info("Deleting traces, count '{}'", ids.size());
return template
.nonTransaction(connection -> feedbackScoreDAO.deleteByEntityIds(EntityType.TRACE, ids, connection))
.nonTransaction(connection -> feedbackScoreDAO.deleteByEntityIds(EntityType.TRACE, ids))
.then(Mono
.defer(() -> template.nonTransaction(connection -> spanDAO.deleteByTraceIds(ids, connection))))
.then(Mono.defer(() -> template.nonTransaction(connection -> dao.delete(ids, connection))));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ public interface TransactionTemplateAsync {
TxConfig WRITE = new TxConfig().readOnly(false);
TxConfig READ_ONLY = new TxConfig().readOnly(true);

/**
 * Static factory returning the default implementation backed by the given
 * {@code ConnectionFactory}; lets call sites outside the DI container
 * obtain an instance.
 */
static TransactionTemplateAsync create(ConnectionFactory connectionFactory) {
return new TransactionTemplateAsyncImpl(connectionFactory);
}

interface TransactionCallback<T> {
Mono<T> execute(Connection handler);
}
Expand Down
Loading

0 comments on commit 2aa5683

Please sign in to comment.