Skip to content

Commit

Permalink
Merge branch 'main' into andrescrz/OPIK-52-find-experiment-by-partial…
Browse files Browse the repository at this point in the history
…-name-case-insensitive
  • Loading branch information
andrescrz authored Sep 4, 2024
2 parents 05128e5 + 22440b5 commit 13ece43
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 159 deletions.
5 changes: 3 additions & 2 deletions apps/opik-backend/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ health:
distributedLock:
lockTimeoutMS: ${DISTRIBUTED_LOCK_TIME_OUT:-500}

# For Redis
# If sentinelMode is true, masterName and nodes are required
bulkOperations:
size: ${BULK_OPERATION_SIZE:-200}

redis:
singleNodeUrl: ${REDIS_URL:-}

Expand Down
1 change: 1 addition & 0 deletions apps/opik-backend/lombok.config
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
config.stopBubbling = true
lombok.anyconstructor.addconstructorproperties = true
lombok.copyableAnnotations += jakarta.inject.Named
lombok.copyableAnnotations += ru.vyarus.dropwizard.guice.module.yaml.bind.Config
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
import com.comet.opik.api.ExperimentItem;
import com.comet.opik.api.FeedbackScore;
import com.comet.opik.api.ScoreSource;
import com.comet.opik.infrastructure.BulkConfig;
import com.comet.opik.infrastructure.db.TransactionTemplate;
import com.comet.opik.utils.AsyncUtils;
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.collect.Lists;
import com.google.inject.ImplementedBy;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Statement;
import jakarta.inject.Inject;
Expand All @@ -25,11 +27,11 @@
import org.stringtemplate.v4.ST;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

import java.math.BigDecimal;
import java.time.Instant;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
Expand All @@ -39,6 +41,8 @@
import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToFlux;
import static com.comet.opik.utils.AsyncUtils.makeFluxContextAware;
import static com.comet.opik.utils.AsyncUtils.makeMonoContextAware;
import static com.comet.opik.utils.TemplateUtils.QueryItem;
import static com.comet.opik.utils.TemplateUtils.getQueryItemPlaceHolder;
import static com.comet.opik.utils.ValidationUtils.CLICKHOUSE_FIXED_STRING_UUID_FIELD_NULL_VALUE;

@ImplementedBy(DatasetItemDAOImpl.class)
Expand Down Expand Up @@ -112,25 +116,29 @@ INSERT INTO dataset_items (
) AS created_by,
new.last_updated_by
FROM (
SELECT
:id AS id,
:datasetId AS dataset_id,
:source AS source,
:traceId AS trace_id,
:spanId AS span_id,
:input AS input,
:expectedOutput AS expected_output,
:metadata AS metadata,
now64(9) AS created_at,
:workspace_id AS workspace_id,
:createdBy AS created_by,
:lastUpdatedBy AS last_updated_by
<items:{item |
SELECT
:id<item.index> AS id,
:datasetId<item.index> AS dataset_id,
:source<item.index> AS source,
:traceId<item.index> AS trace_id,
:spanId<item.index> AS span_id,
:input<item.index> AS input,
:expectedOutput<item.index> AS expected_output,
:metadata<item.index> AS metadata,
now64(9) AS created_at,
:workspace_id<item.index> AS workspace_id,
:createdBy<item.index> AS created_by,
:lastUpdatedBy<item.index> AS last_updated_by
<if(item.hasNext)>
UNION ALL
<endif>
}>
) AS new
LEFT JOIN (
SELECT
*
FROM dataset_items
WHERE id = :id
ORDER BY last_updated_at DESC
LIMIT 1 BY id
) AS old
Expand Down Expand Up @@ -371,6 +379,7 @@ LEFT JOIN (
""";

private final @NonNull TransactionTemplate asyncTemplate;
private final @NonNull @Config("bulkOperations") BulkConfig bulkConfig;

@Override
public Mono<Long> save(@NonNull UUID datasetId, @NonNull List<DatasetItem> items) {
Expand All @@ -379,47 +388,49 @@ public Mono<Long> save(@NonNull UUID datasetId, @NonNull List<DatasetItem> items
return Mono.empty();
}

return inset(datasetId, items)
.retryWhen(AsyncUtils.handleConnectionError());
return inset(datasetId, items);
}

private Mono<Long> inset(UUID datasetId, List<DatasetItem> items) {
return asyncTemplate.nonTransaction(connection -> {
List<List<DatasetItem>> batches = Lists.partition(items, bulkConfig.getSize());

var statement = connection.createStatement(INSERT_DATASET_ITEM);

return mapAndInsert(datasetId, items, statement)
.flatMap(Result::getRowsUpdated)
.reduce(0L, Long::sum);
});
return Flux.fromIterable(batches)
.flatMapSequential(batch -> asyncTemplate.nonTransaction(connection -> mapAndInsert(datasetId, batch, connection)))
.reduce(0L, Long::sum);
}

private Flux<? extends Result> mapAndInsert(UUID datasetId, List<DatasetItem> items, Statement statement) {
return makeFluxContextAware((userName, workspaceName, workspaceId) -> {

for (Iterator<DatasetItem> iterator = items.iterator(); iterator.hasNext();) {
var item = iterator.next();

statement.bind("id", item.id())
.bind("datasetId", datasetId)
.bind("input", item.input().toString())
.bind("source", item.source().getValue())
.bind("traceId", getOrDefault(item.traceId()))
.bind("spanId", getOrDefault(item.spanId()))
.bind("expectedOutput", getOrDefault(item.expectedOutput()))
.bind("metadata", getOrDefault(item.metadata()))
.bind("workspace_id", workspaceId)
.bind("createdBy", userName)
.bind("lastUpdatedBy", userName);

if (iterator.hasNext()) {
statement.add();
}
}
private Mono<Long> mapAndInsert(UUID datasetId, List<DatasetItem> items, Connection connection) {

List<QueryItem> queryItems = getQueryItemPlaceHolder(items);

statement.fetchSize(items.size());
var template = new ST(INSERT_DATASET_ITEM)
.add("items", queryItems);

return Flux.from(statement.execute());
String sql = template.render();

var statement = connection.createStatement(sql);

return makeMonoContextAware((userName, workspaceName, workspaceId) -> {

int i = 0;
for (DatasetItem item : items) {
statement.bind("id" + i, item.id());
statement.bind("datasetId" + i, datasetId);
statement.bind("source" + i, item.source().getValue());
statement.bind("traceId" + i, getOrDefault(item.traceId()));
statement.bind("spanId" + i, getOrDefault(item.spanId()));
statement.bind("input" + i, getOrDefault(item.input()));
statement.bind("expectedOutput" + i, getOrDefault(item.expectedOutput()));
statement.bind("metadata" + i, getOrDefault(item.metadata()));
statement.bind("workspace_id" + i, workspaceId);
statement.bind("createdBy" + i,userName);
statement.bind("lastUpdatedBy" + i, userName);
i++;
}

return Flux.from(statement.execute())
.flatMap(Result::getRowsUpdated)
.reduce(0L, Long::sum);
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.comet.opik.domain;

import com.comet.opik.api.ExperimentItem;
import com.comet.opik.infrastructure.BulkConfig;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Result;
Expand All @@ -13,16 +15,22 @@
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.reactivestreams.Publisher;
import org.stringtemplate.v4.ST;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

import java.time.Instant;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.UUID;

import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToFlux;
import static com.comet.opik.utils.AsyncUtils.makeFluxContextAware;
import static com.comet.opik.utils.AsyncUtils.makeMonoContextAware;
import static com.comet.opik.utils.TemplateUtils.QueryItem;
import static com.comet.opik.utils.TemplateUtils.getQueryItemPlaceHolder;

@Singleton
@RequiredArgsConstructor(onConstructor_ = @Inject)
Expand Down Expand Up @@ -65,20 +73,24 @@ INSERT INTO experiment_items (
new.created_by,
new.last_updated_by
FROM (
SELECT
:id AS id,
:experiment_id AS experiment_id,
:dataset_item_id AS dataset_item_id,
:trace_id AS trace_id,
:workspace_id AS workspace_id,
:created_by AS created_by,
:last_updated_by AS last_updated_by
) AS new
<items:{item |
SELECT
:id<item.index> AS id,
:experiment_id<item.index> AS experiment_id,
:dataset_item_id<item.index> AS dataset_item_id,
:trace_id<item.index> AS trace_id,
:workspace_id<item.index> AS workspace_id,
:created_by<item.index> AS created_by,
:last_updated_by<item.index> AS last_updated_by
<if(item.hasNext)>
UNION ALL
<endif>
}>
) AS new
LEFT JOIN (
SELECT
id, workspace_id
FROM experiment_items
WHERE id = :id
ORDER BY last_updated_at DESC
LIMIT 1 BY id
) AS old
Expand Down Expand Up @@ -119,6 +131,7 @@ LEFT JOIN (
""";

private final @NonNull ConnectionFactory connectionFactory;
private final @NonNull @Config("bulkOperations") BulkConfig bulkConfig;

public Flux<ExperimentSummary> findExperimentSummaryByDatasetIds(Collection<UUID> datasetIds) {

Expand All @@ -143,39 +156,49 @@ public Flux<ExperimentSummary> findExperimentSummaryByDatasetIds(Collection<UUID
public Mono<Long> insert(@NonNull Set<ExperimentItem> experimentItems) {
Preconditions.checkArgument(CollectionUtils.isNotEmpty(experimentItems),
"Argument 'experimentItems' must not be empty");
return Mono.from(connectionFactory.create())
.flatMapMany(connection -> insert(experimentItems, connection))

log.info("Inserting experiment items, count '{}'", experimentItems.size());

List<List<ExperimentItem>> batches = Lists.partition(List.copyOf(experimentItems), bulkConfig.getSize());

return Flux.fromIterable(batches)
.flatMapSequential(batch -> Mono.from(connectionFactory.create())
.flatMap(connection -> insert(experimentItems, connection)))
.reduce(0L, Long::sum);
}

private Flux<Long> insert(Set<ExperimentItem> experimentItems, Connection connection) {
private Mono<Long> insert(Collection<ExperimentItem> experimentItems, Connection connection) {

log.info("Inserting experiment items, count '{}'", experimentItems.size());
var statement = connection.createStatement(INSERT);

return makeFluxContextAware((userName, workspaceName, workspaceId) -> {

for (var iterator = experimentItems.iterator(); iterator.hasNext();) {
var item = iterator.next();
statement.bind("id", item.id())
.bind("experiment_id", item.experimentId())
.bind("dataset_item_id", item.datasetItemId())
.bind("trace_id", item.traceId())
.bind("workspace_id", workspaceId)
.bind("created_by", userName)
.bind("last_updated_by", userName);

if (iterator.hasNext()) {
statement.add();
}
}
List<QueryItem> queryItems = getQueryItemPlaceHolder(experimentItems);

var template = new ST(INSERT)
.add("items", queryItems);

String sql = template.render();

var statement = connection.createStatement(sql);

statement.fetchSize(experimentItems.size());
return makeMonoContextAware((userName, workspaceName, workspaceId) -> {

return Flux.from(statement.execute()).flatMap(Result::getRowsUpdated);
int index = 0;
for (ExperimentItem item : experimentItems) {
statement.bind("id" + index, item.id());
statement.bind("experiment_id" + index, item.experimentId());
statement.bind("dataset_item_id" + index, item.datasetItemId());
statement.bind("trace_id" + index, item.traceId());
statement.bind("workspace_id" + index, workspaceId);
statement.bind("created_by" + index, userName);
statement.bind("last_updated_by" + index, userName);
index++;
}

return Flux.from(statement.execute())
.flatMap(Result::getRowsUpdated)
.reduce(0L, Long::sum);
});
}


private Publisher<ExperimentItem> mapToExperimentItem(Result result) {
return result.map((row, rowMetadata) -> ExperimentItem.builder()
.id(row.get("id", UUID.class))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.comet.opik.infrastructure;


import com.fasterxml.jackson.annotation.JsonProperty;
import jakarta.validation.Valid;
import lombok.Data;

@Data
public class BulkConfig {

@Valid
@JsonProperty
private Integer size;

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,8 @@ public class OpikConfiguration extends Configuration {
@Valid
@NotNull @JsonProperty
private DistributedLockConfig distributedLock = new DistributedLockConfig();

@Valid
@NotNull @JsonProperty
private BulkConfig bulkOperations = new BulkConfig();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.comet.opik.utils;

import java.util.Collection;
import java.util.List;
import java.util.stream.IntStream;

public class TemplateUtils {

public static class QueryItem {
public final int index;
public final boolean hasNext;

public QueryItem(int index, boolean hasNext) {
this.index = index;
this.hasNext = hasNext;
}
}

public static List<QueryItem> getQueryItemPlaceHolder(Collection<?> items) {

if (items == null || items.isEmpty()) {
return List.of();
}

return IntStream.range(0, items.size())
.mapToObj(i -> new QueryItem(i, i < items.size() - 1))
.toList();
}
}
Loading

0 comments on commit 13ece43

Please sign in to comment.