[compute-ai-embeddings] Compute Embeddings Agent: execute calls asynchronously and in batches while preserving per-key ordering (#457)
eolivelli authored Sep 21, 2023
1 parent 68b9cdc commit 312f6dc
Showing 13 changed files with 751 additions and 172 deletions.
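
The core idea of the change: instead of blocking on every embeddings call to preserve ordering, records are hashed by key into a fixed number of buckets, and each bucket processes its batches asynchronously but strictly in FIFO order. Records that share a key always land in the same bucket, so their relative order is preserved, while records with different keys can proceed in parallel. Below is a minimal sketch of that routing idea; it is illustrative only, not the actual ai.langstream.api.util.OrderedAsyncBatchExecutor, and all names and internals here are assumptions.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiConsumer;

/** Illustrative sketch of per-key ordered asynchronous batching. */
class OrderedBatcherSketch<T> {
    private final int batchSize;
    private final List<List<T>> buckets;               // pending records, one list per bucket
    private final List<CompletableFuture<Void>> tails; // completion of the last batch per bucket
    private final BiConsumer<List<T>, CompletableFuture<Void>> processor;

    OrderedBatcherSketch(int numBuckets, int batchSize,
                         BiConsumer<List<T>, CompletableFuture<Void>> processor) {
        this.batchSize = batchSize;
        this.processor = processor;
        this.buckets = new ArrayList<>();
        this.tails = new ArrayList<>();
        for (int i = 0; i < numBuckets; i++) {
            buckets.add(new ArrayList<>(batchSize));
            tails.add(CompletableFuture.completedFuture(null));
        }
    }

    synchronized void add(T record, int hash) {
        // same key -> same hash -> same bucket, so per-key order is preserved
        int bucket = Math.floorMod(hash, buckets.size());
        List<T> pending = buckets.get(bucket);
        pending.add(record);
        if (pending.size() >= batchSize) {
            List<T> batch = new ArrayList<>(pending);
            pending.clear();
            CompletableFuture<Void> handle = new CompletableFuture<>();
            // chain the new batch after the previous one: batches within a
            // bucket run one at a time, while different buckets overlap freely
            tails.set(bucket, tails.get(bucket)
                    .thenRun(() -> processor.accept(batch, handle))
                    .thenCompose(ignored -> handle));
        }
    }
}

Note in the first diff below how keyless records get a random hash: with no key there is no ordering constraint, so the random spread simply balances load across buckets.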
ComputeAIEmbeddingsStep.java
@@ -15,31 +15,37 @@
*/
package com.datastax.oss.streaming.ai;

import ai.langstream.api.util.OrderedAsyncBatchExecutor;
import com.datastax.oss.streaming.ai.embeddings.EmbeddingsService;
import com.datastax.oss.streaming.ai.model.JsonRecord;
import com.datastax.oss.streaming.ai.util.TransformFunctionUtil;
import com.samskivert.mustache.Mustache;
import com.samskivert.mustache.Template;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import lombok.extern.slf4j.Slf4j;
import org.apache.avro.Schema;

/**
* Computes AI embeddings from a template filled with the received message's fields and
* metadata, and puts the result into a new or existing field.
*/
@Slf4j
public class ComputeAIEmbeddingsStep implements TransformStep {

static final Random RANDOM = new Random();

private final Template template;
private final String embeddingsFieldName;
private final EmbeddingsService embeddingsService;

private final TransformFunctionUtil.BatchExecutor<RecordHolder> batchExecutor;
private final OrderedAsyncBatchExecutor<RecordHolder> batchExecutor;

private final ScheduledExecutorService executorService;
private final Map<org.apache.avro.Schema, org.apache.avro.Schema> avroValueSchemaCache =
@@ -53,59 +59,93 @@ public ComputeAIEmbeddingsStep(
String embeddingsFieldName,
int batchSize,
long flushInterval,
int concurrency,
EmbeddingsService embeddingsService) {
this.template = Mustache.compiler().compile(text);
this.embeddingsFieldName = embeddingsFieldName;
this.embeddingsService = embeddingsService;
this.executorService =
flushInterval > 0 ? Executors.newSingleThreadScheduledExecutor() : null;
int numBuckets = concurrency > 0 ? concurrency : 1;
this.batchExecutor =
new TransformFunctionUtil.BatchExecutor<>(
batchSize, this::processBatch, flushInterval, executorService);
new OrderedAsyncBatchExecutor<>(
batchSize,
this::processBatch,
flushInterval,
numBuckets,
ComputeAIEmbeddingsStep::computeHashForRecord,
executorService);
}

private static int computeHashForRecord(RecordHolder record) {
Object key = record.transformContext.getKeyObject();
if (key != null) {
return Objects.hashCode(key);
} else {
return RANDOM.nextInt();
}
}

@Override
public void start() throws Exception {
batchExecutor.start();
}

private void processBatch(List<RecordHolder> records) {
private void processBatch(List<RecordHolder> records, CompletableFuture<?> completionHandle) {

// prepare batch API call
List<String> texts = new ArrayList<>();
for (RecordHolder holder : records) {
TransformContext transformContext = holder.transformContext();
JsonRecord jsonRecord = transformContext.toJsonRecord();
String text = template.execute(jsonRecord);
texts.add(text);

try {
for (RecordHolder holder : records) {
TransformContext transformContext = holder.transformContext();
JsonRecord jsonRecord = transformContext.toJsonRecord();
String text = template.execute(jsonRecord);
texts.add(text);
}
} catch (Throwable error) {
// we cannot fail only some records, because we must keep the order
log.error(
"At least one error failed the conversion to JSON, failing the whole batch",
error);
errorForAll(records, error);
completionHandle.complete(null);
return;
}

CompletableFuture<List<List<Double>>> embeddings =
embeddingsService.computeEmbeddings(texts);

embeddings.whenComplete(
(result, error) -> {
if (error != null) {
for (int i = 0; i < records.size(); i++) {
RecordHolder holder = records.get(i);
holder.handle.completeExceptionally(error);
}
return;
}

for (int i = 0; i < records.size(); i++) {
RecordHolder holder = records.get(i);
TransformContext transformContext = holder.transformContext();
List<Double> embeddingsForText = result.get(i);
transformContext.setResultField(
embeddingsForText,
embeddingsFieldName,
Schema.createArray(Schema.create(Schema.Type.DOUBLE)),
avroKeySchemaCache,
avroValueSchemaCache);
holder.handle().complete(null);
}
});
embeddings
.whenComplete(
(result, error) -> {
if (error != null) {
errorForAll(records, error);
} else {
for (int i = 0; i < records.size(); i++) {
RecordHolder holder = records.get(i);
TransformContext transformContext = holder.transformContext();
List<Double> embeddingsForText = result.get(i);
transformContext.setResultField(
embeddingsForText,
embeddingsFieldName,
Schema.createArray(Schema.create(Schema.Type.DOUBLE)),
avroKeySchemaCache,
avroValueSchemaCache);
holder.handle().complete(null);
}
}
})
.whenComplete(
(a, b) -> {
completionHandle.complete(null);
});
}

private static void errorForAll(List<RecordHolder> records, Throwable error) {
for (RecordHolder holder : records) {
holder.handle.completeExceptionally(error);
}
}

@Override
OpenAIEmbeddingsService.java
@@ -37,24 +37,15 @@ public OpenAIEmbeddingsService(OpenAIAsyncClient openAIClient, String model) {
public CompletableFuture<List<List<Double>>> computeEmbeddings(List<String> texts) {
try {
EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(texts);
CompletableFuture<List<List<Double>>> result =
openAIClient
.getEmbeddings(model, embeddingsOptions)
.toFuture()
.thenApply(
embeddings ->
embeddings.getData().stream()
.map(embedding -> embedding.getEmbedding())
.collect(Collectors.toList()));
return openAIClient
.getEmbeddings(model, embeddingsOptions)
.toFuture()
.thenApply(
embeddings ->
embeddings.getData().stream()
.map(embedding -> embedding.getEmbedding())
.collect(Collectors.toList()));

// we need to wait, in order to guarantee ordering
// TODO: we should not wait, but instead use an ordered executor (per key)
try {
result.join();
} catch (Throwable err) {
log.error("Cannot compute embeddings", err);
}
return result;
} catch (RuntimeException err) {
log.error("Cannot compute embeddings", err);
return CompletableFuture.failedFuture(err);
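
The hunk above removes the blocking workaround flagged by the old TODO: computeEmbeddings() used to call result.join() so that ordering was preserved by serializing every request on the calling thread. The method now returns the future as-is, and ordering is enforced upstream, per key, by the OrderedAsyncBatchExecutor. A before/after sketch (compute() is a hypothetical stand-in for the OpenAI call chain):

// before: block until the call finishes, serializing all requests
List<List<Double>> embeddings = compute(texts).join();

// after: hand back the future; the executor serializes only batches that
// share a bucket, so calls for independent keys overlap in flight
CompletableFuture<List<List<Double>>> future = compute(texts);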
ComputeAIEmbeddingsConfig.java
@@ -34,6 +34,9 @@ public class ComputeAIEmbeddingsConfig extends StepConfig {
@JsonProperty("batch-size")
private int batchSize = 10;

@JsonProperty("concurrency")
private int concurrency = 4;

// we disable flushing by default in order to avoid latency spikes
// you should enable this feature in the case of background processing
@JsonProperty("flush-interval")
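
The new concurrency property controls the number of ordering buckets and defaults to 4. A hypothetical snippet showing how the config above would deserialize (field names come from the @JsonProperty annotations; the surrounding pipeline wiring and Jackson's checked exception are omitted here):

import com.fasterxml.jackson.databind.ObjectMapper;

ComputeAIEmbeddingsConfig config =
        new ObjectMapper().readValue(
                "{\"batch-size\": 10, \"concurrency\": 1, \"flush-interval\": 0}",
                ComputeAIEmbeddingsConfig.class);
// a concurrency of 1 (or any value <= 0, see numBuckets above) means a single
// bucket, i.e. strict global ordering across all records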
TransformFunctionUtil.java
@@ -78,10 +78,6 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Predicate;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
@@ -310,6 +306,7 @@ public static TransformStep newComputeAIEmbeddings(
config.getEmbeddingsFieldName(),
config.getBatchSize(),
config.getFlushInterval(),
config.getConcurrency(),
embeddingsService);
}

@@ -569,78 +566,4 @@ public X509Certificate[] getAcceptedIssuers() {
}
}
}

/**
* Aggregate records in batches, depending on a batch size and a maximum idle time.
*
* @param <T>
*/
public static class BatchExecutor<T> {
private final int batchSize;
private List<T> batch;
private long flushInterval;
private ScheduledExecutorService scheduledExecutorService;

private ScheduledFuture<?> scheduledFuture;

private final Consumer<List<T>> processor;

public BatchExecutor(
int batchSize,
Consumer<List<T>> processor,
long maxIdleTime,
ScheduledExecutorService scheduledExecutorService) {
this.batchSize = batchSize;
this.batch = new ArrayList<>(batchSize);
this.processor = processor;
this.flushInterval = maxIdleTime;
this.scheduledExecutorService = scheduledExecutorService;
}

public void start() {
if (flushInterval > 0) {
scheduledFuture =
scheduledExecutorService.scheduleWithFixedDelay(
this::flush, flushInterval, flushInterval, TimeUnit.MILLISECONDS);
}
}

public void stop() {
if (scheduledFuture != null) {
scheduledFuture.cancel(false);
}
flush();
}

private void flush() {
List<T> batchToProcess = null;
synchronized (this) {
if (batch.isEmpty()) {
return;
}
batchToProcess = batch;
batch = new ArrayList<>(batchSize);
}
// execute the processor out of the synchronized block
processor.accept(batchToProcess);
}

public void add(T t) {
List<T> batchToProcess = null;
synchronized (this) {
batch.add(t);
if (batch.size() >= batchSize || flushInterval <= 0) {
batchToProcess = batch;
batch = new ArrayList<>(batchSize);
}
}

// execute the processor out of the synchronized block
if (batchToProcess != null) {
processor.accept(batchToProcess);
}
}
}
}
@@ -66,6 +66,7 @@ void testAvro() throws Exception {
"value.newField",
1,
500,
1,
mockService);

Record<?> outputRecord = Utils.process(record, step);
@@ -84,7 +85,7 @@ void testKeyValueAvro() throws Exception {
mockService.setEmbeddingsForText("key1", expectedEmbeddings);
ComputeAIEmbeddingsStep step =
new ComputeAIEmbeddingsStep(
"{{ key.keyField1 }}", "value.newField", 1, 500, mockService);
"{{ key.keyField1 }}", "value.newField", 1, 500, 1, mockService);

Record<?> outputRecord = Utils.process(Utils.createTestAvroKeyValueRecord(), step);
KeyValueSchema<?, ?> messageSchema = (KeyValueSchema<?, ?>) outputRecord.getSchema();
@@ -123,6 +124,7 @@ void testJson() throws Exception {
"value.newField",
1,
500,
1,
mockService);

Record<?> outputRecord = Utils.process(record, step);
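
Note that the tests above all pass concurrency = 1. A variant exercising multiple buckets would only change that argument (hypothetical, reusing the helpers from the existing tests):

ComputeAIEmbeddingsStep step =
        new ComputeAIEmbeddingsStep(
                "{{ key.keyField1 }}", "value.newField", 1, 500, 4, mockService);
// with 4 buckets, records whose keys hash to different buckets may complete
// concurrently; records sharing a key still complete in submission order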