diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeAIEmbeddingsStep.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeAIEmbeddingsStep.java
index 3ece38483..8aacc6f44 100644
--- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeAIEmbeddingsStep.java
+++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/ComputeAIEmbeddingsStep.java
@@ -15,31 +15,37 @@
  */
 package com.datastax.oss.streaming.ai;
 
+import ai.langstream.api.util.OrderedAsyncBatchExecutor;
 import com.datastax.oss.streaming.ai.embeddings.EmbeddingsService;
 import com.datastax.oss.streaming.ai.model.JsonRecord;
-import com.datastax.oss.streaming.ai.util.TransformFunctionUtil;
 import com.samskivert.mustache.Mustache;
 import com.samskivert.mustache.Template;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
+import lombok.extern.slf4j.Slf4j;
 import org.apache.avro.Schema;
 
 /**
  * Compute AI Embeddings from a template filled with the received message fields and metadata and
  * put the value into a new or existing field.
  */
+@Slf4j
 public class ComputeAIEmbeddingsStep implements TransformStep {
+    static final Random RANDOM = new Random();
+
     private final Template template;
     private final String embeddingsFieldName;
     private final EmbeddingsService embeddingsService;
-    private final TransformFunctionUtil.BatchExecutor<RecordHolder> batchExecutor;
+    private final OrderedAsyncBatchExecutor<RecordHolder> batchExecutor;
     private final ScheduledExecutorService executorService;
 
     private final Map<Schema, Schema> avroValueSchemaCache =
@@ -53,15 +59,31 @@ public ComputeAIEmbeddingsStep(
             String embeddingsFieldName,
             int batchSize,
             long flushInterval,
+            int concurrency,
             EmbeddingsService embeddingsService) {
         this.template = Mustache.compiler().compile(text);
         this.embeddingsFieldName = embeddingsFieldName;
         this.embeddingsService = embeddingsService;
         this.executorService =
                 flushInterval > 0 ? Executors.newSingleThreadScheduledExecutor() : null;
+        int numBuckets = concurrency > 0 ? concurrency : 1;
         this.batchExecutor =
-                new TransformFunctionUtil.BatchExecutor<>(
-                        batchSize, this::processBatch, flushInterval, executorService);
+                new OrderedAsyncBatchExecutor<>(
+                        batchSize,
+                        this::processBatch,
+                        flushInterval,
+                        numBuckets,
+                        ComputeAIEmbeddingsStep::computeHashForRecord,
+                        executorService);
+    }
+
+    private static int computeHashForRecord(RecordHolder record) {
+        Object key = record.transformContext.getKeyObject();
+        if (key != null) {
+            return Objects.hashCode(key);
+        } else {
+            return RANDOM.nextInt();
+        }
     }
 
     @Override
@@ -69,43 +91,61 @@ public void start() throws Exception {
         batchExecutor.start();
     }
 
-    private void processBatch(List<RecordHolder> records) {
+    private void processBatch(List<RecordHolder> records, CompletableFuture<?> completionHandle) {
         // prepare batch API call
         List<String> texts = new ArrayList<>();
-        for (RecordHolder holder : records) {
-            TransformContext transformContext = holder.transformContext();
-            JsonRecord jsonRecord = transformContext.toJsonRecord();
-            String text = template.execute(jsonRecord);
-            texts.add(text);
+
+        try {
+            for (RecordHolder holder : records) {
+                TransformContext transformContext = holder.transformContext();
+                JsonRecord jsonRecord = transformContext.toJsonRecord();
+                String text = template.execute(jsonRecord);
+                texts.add(text);
+            }
+        } catch (Throwable error) {
+            // we cannot fail only some records, because we must keep the order
+            log.error(
+                    "At least one error failed the conversion to JSON, failing the whole batch",
+                    error);
+            errorForAll(records, error);
+            completionHandle.complete(null);
+            return;
        }
 
         CompletableFuture<List<List<Double>>> embeddings =
                 embeddingsService.computeEmbeddings(texts);
 
-        embeddings.whenComplete(
-                (result, error) -> {
-                    if (error != null) {
-                        for (int i = 0; i < records.size(); i++) {
-                            RecordHolder holder = records.get(i);
-                            holder.handle.completeExceptionally(error);
-                        }
-                        return;
-                    }
-
-                    for (int i = 0; i < records.size(); i++) {
-                        RecordHolder holder = records.get(i);
-                        TransformContext transformContext = holder.transformContext();
-                        List<Double> embeddingsForText = result.get(i);
-                        transformContext.setResultField(
-                                embeddingsForText,
-                                embeddingsFieldName,
-                                Schema.createArray(Schema.create(Schema.Type.DOUBLE)),
-                                avroKeySchemaCache,
-                                avroValueSchemaCache);
-                        holder.handle().complete(null);
-                    }
-                });
+        embeddings
+                .whenComplete(
+                        (result, error) -> {
+                            if (error != null) {
+                                errorForAll(records, error);
+                            } else {
+                                for (int i = 0; i < records.size(); i++) {
+                                    RecordHolder holder = records.get(i);
+                                    TransformContext transformContext = holder.transformContext();
+                                    List<Double> embeddingsForText = result.get(i);
+                                    transformContext.setResultField(
+                                            embeddingsForText,
+                                            embeddingsFieldName,
+                                            Schema.createArray(Schema.create(Schema.Type.DOUBLE)),
+                                            avroKeySchemaCache,
+                                            avroValueSchemaCache);
+                                    holder.handle().complete(null);
+                                }
+                            }
+                        })
+                .whenComplete(
+                        (a, b) -> {
+                            completionHandle.complete(null);
+                        });
+    }
+
+    private static void errorForAll(List<RecordHolder> records, Throwable error) {
+        for (RecordHolder holder : records) {
+            holder.handle.completeExceptionally(error);
+        }
     }
 
     @Override
diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/embeddings/OpenAIEmbeddingsService.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/embeddings/OpenAIEmbeddingsService.java
index 1e6dfdd10..89a13d53f 100644
--- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/embeddings/OpenAIEmbeddingsService.java
+++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/embeddings/OpenAIEmbeddingsService.java
@@ -37,24 +37,15 @@ public OpenAIEmbeddingsService(OpenAIAsyncClient openAIClient, String model) {
     public CompletableFuture<List<List<Double>>> computeEmbeddings(List<String> texts) {
         try {
             EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(texts);
-            CompletableFuture<List<List<Double>>> result =
-                    openAIClient
-                            .getEmbeddings(model, embeddingsOptions)
-                            .toFuture()
-                            .thenApply(
-                                    embeddings ->
-                                            embeddings.getData().stream()
-                                                    .map(embedding -> embedding.getEmbedding())
-                                                    .collect(Collectors.toList()));
+            return openAIClient
+                    .getEmbeddings(model, embeddingsOptions)
+                    .toFuture()
+                    .thenApply(
+                            embeddings ->
+                                    embeddings.getData().stream()
+                                            .map(embedding -> embedding.getEmbedding())
+                                            .collect(Collectors.toList()));
 
-            // we need to wait, in order to guarantee ordering
-            // TODO: we should not wait, but instead use an ordered executor (per key)
-            try {
-                result.join();
-            } catch (Throwable err) {
-                log.error("Cannot compute embeddings", err);
-            }
-            return result;
         } catch (RuntimeException err) {
             log.error("Cannot compute embeddings", err);
             return CompletableFuture.failedFuture(err);
diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/config/ComputeAIEmbeddingsConfig.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/config/ComputeAIEmbeddingsConfig.java
index 89cae394b..be5d96f5b 100644
--- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/config/ComputeAIEmbeddingsConfig.java
+++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/config/ComputeAIEmbeddingsConfig.java
@@ -34,6 +34,9 @@ public class ComputeAIEmbeddingsConfig extends StepConfig {
     @JsonProperty("batch-size")
     private int batchSize = 10;
 
+    @JsonProperty("concurrency")
+    private int concurrency = 4;
+
     // we disable flushing by default in order to avoid latency spikes
     // you should enable this feature in the case of background processing
     @JsonProperty("flush-interval")
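Note on the `OpenAIEmbeddingsService` change above: the blocking `result.join()` (and its TODO about an ordered per-key executor) disappears because ordering is no longer the service's problem. A minimal sketch of the difference, where `compute` is a hypothetical stand-in for `openAIClient.getEmbeddings(...).toFuture()`:

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;

public class BlockingVsAsyncSketch {
    // hypothetical stand-in for the asynchronous OpenAI embeddings call
    static CompletableFuture<List<Double>> compute(String text) {
        return CompletableFuture.supplyAsync(() -> List.of(1.0, 2.0));
    }

    public static void main(String[] args) {
        // before: joining each future kept results ordered, but capped the
        // service at a single request in flight
        compute("a").join();
        compute("b").join();

        // after: futures are returned as-is, so several batches can be in
        // flight at once; per-key ordering is enforced by the
        // OrderedAsyncBatchExecutor introduced later in this diff
        CompletableFuture<?> first = compute("a");
        CompletableFuture<?> second = compute("b");
        CompletableFuture.allOf(first, second).join();
    }
}
```

This is also why `ComputeAIEmbeddingsStep` now hashes records by key: equal keys map to the same bucket, so their batches still complete in submission order, while keyless records get a random hash and spread across all buckets.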
diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/util/TransformFunctionUtil.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/util/TransformFunctionUtil.java
index b6c4b1cf2..ff3da0ceb 100644
--- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/util/TransformFunctionUtil.java
+++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/util/TransformFunctionUtil.java
@@ -78,10 +78,6 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.ScheduledFuture;
-import java.util.concurrent.TimeUnit;
-import java.util.function.Consumer;
 import java.util.function.Predicate;
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.TrustManager;
@@ -310,6 +306,7 @@ public static TransformStep newComputeAIEmbeddings(
                 config.getEmbeddingsFieldName(),
                 config.getBatchSize(),
                 config.getFlushInterval(),
+                config.getConcurrency(),
                 embeddingsService);
     }
 
@@ -569,78 +566,4 @@ public X509Certificate[] getAcceptedIssuers() {
             }
         }
     }
-
-    /**
-     * Aggregate records in batches, depending on a batch size and a maximum idle time.
-     *
-     * @param <T> the type of the records
-     */
-    public static class BatchExecutor<T> {
-        private final int batchSize;
-        private List<T> batch;
-        private long flushInterval;
-        private ScheduledExecutorService scheduledExecutorService;
-
-        private ScheduledFuture<?> scheduledFuture;
-
-        private final Consumer<List<T>> processor;
-
-        public BatchExecutor(
-                int batchSize,
-                Consumer<List<T>> processor,
-                long maxIdleTime,
-                ScheduledExecutorService scheduledExecutorService) {
-            this.batchSize = batchSize;
-            this.batch = new ArrayList<>(batchSize);
-            this.processor = processor;
-            this.flushInterval = maxIdleTime;
-            this.scheduledExecutorService = scheduledExecutorService;
-        }
-
-        public void start() {
-            if (flushInterval > 0) {
-                scheduledFuture =
-                        scheduledExecutorService.scheduleWithFixedDelay(
-                                this::flush, flushInterval, flushInterval, TimeUnit.MILLISECONDS);
-            }
-        }
-
-        public void stop() {
-            if (scheduledFuture != null) {
-                scheduledFuture.cancel(false);
-            }
-            flush();
-        }
-
-        private void flush() {
-            List<T> batchToProcess = null;
-            synchronized (this) {
-                if (batch.isEmpty()) {
-                    return;
-                }
-                if (!batch.isEmpty()) {
-                    batchToProcess = batch;
-                    batch = new ArrayList<>(batchSize);
-                }
-            }
-            // execute the processor our of the synchronized block
-            processor.accept(batchToProcess);
-        }
-
-        public void add(T t) {
-            List<T> batchToProcess = null;
-            synchronized (this) {
-                batch.add(t);
-                if (batch.size() >= batchSize || flushInterval <= 0) {
-                    batchToProcess = batch;
-                    batch = new ArrayList<>(batchSize);
-                }
-            }
-
-            // execute the processor our of the synchronized block
-            if (batchToProcess != null) {
-                processor.accept(batchToProcess);
-            }
-        }
-    }
 }
diff --git a/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/ComputeAIEmbeddingsTest.java b/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/ComputeAIEmbeddingsTest.java
index 2face4bbc..9bdb6640d 100644
--- a/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/ComputeAIEmbeddingsTest.java
+++ b/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/ComputeAIEmbeddingsTest.java
@@ -66,6 +66,7 @@ void testAvro() throws Exception {
                         "value.newField",
                         1,
                         500,
+                        1,
                         mockService);
 
         Record outputRecord = Utils.process(record, step);
@@ -84,7 +85,7 @@ void testKeyValueAvro() throws Exception {
         mockService.setEmbeddingsForText("key1", expectedEmbeddings);
         ComputeAIEmbeddingsStep step =
                 new ComputeAIEmbeddingsStep(
-                        "{{ key.keyField1 }}", "value.newField", 1, 500, mockService);
+                        "{{ key.keyField1 }}", "value.newField", 1, 500, 1, mockService);
 
         Record outputRecord = Utils.process(Utils.createTestAvroKeyValueRecord(), step);
         KeyValueSchema messageSchema = (KeyValueSchema) outputRecord.getSchema();
@@ -123,6 +124,7 @@ void testJson() throws Exception {
                         "value.newField",
                         1,
                         500,
+                        1,
                         mockService);
 
         Record outputRecord = Utils.process(record, step);
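The synchronous `BatchExecutor` is not deleted outright: it moves from `TransformFunctionUtil` to the shared `langstream-api` module in the new file below, essentially verbatim (the "our of" comment typo and the redundant `!batch.isEmpty()` check are cleaned up along the way). A minimal usage sketch against the constructor shown in the new file; the batch size, idle time, and println processor are ours:

```java
import ai.langstream.api.util.BatchExecutor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;

public class BatchExecutorUsageSketch {
    public static void main(String[] args) {
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        // flush when 10 records have accumulated, or after 500 ms of idle time
        BatchExecutor<String> executor =
                new BatchExecutor<>(
                        10,
                        batch -> System.out.println("processing " + batch),
                        500,
                        scheduler);
        executor.start();
        executor.add("record-1");
        executor.add("record-2");
        executor.stop(); // cancels the timer and flushes the leftovers
        scheduler.shutdown();
    }
}
```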
diff --git a/langstream-api/src/main/java/ai/langstream/api/util/BatchExecutor.java b/langstream-api/src/main/java/ai/langstream/api/util/BatchExecutor.java
new file mode 100644
index 000000000..8802cb236
--- /dev/null
+++ b/langstream-api/src/main/java/ai/langstream/api/util/BatchExecutor.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.langstream.api.util;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+/**
+ * Aggregate records in batches, depending on a batch size and a maximum idle time.
+ *
+ * @param <T> the type of the records
+ */
+public class BatchExecutor<T> {
+    private final int batchSize;
+    private List<T> batch;
+    private long flushInterval;
+    private ScheduledExecutorService scheduledExecutorService;
+
+    private ScheduledFuture<?> scheduledFuture;
+
+    private final Consumer<List<T>> processor;
+
+    public BatchExecutor(
+            int batchSize,
+            Consumer<List<T>> processor,
+            long maxIdleTime,
+            ScheduledExecutorService scheduledExecutorService) {
+        this.batchSize = batchSize;
+        this.batch = new ArrayList<>(batchSize);
+        this.processor = processor;
+        this.flushInterval = maxIdleTime;
+        this.scheduledExecutorService = scheduledExecutorService;
+    }
+
+    public void start() {
+        if (flushInterval > 0) {
+            scheduledFuture =
+                    scheduledExecutorService.scheduleWithFixedDelay(
+                            this::flush, flushInterval, flushInterval, TimeUnit.MILLISECONDS);
+        }
+    }
+
+    public void stop() {
+        if (scheduledFuture != null) {
+            scheduledFuture.cancel(false);
+        }
+        flush();
+    }
+
+    private void flush() {
+        List<T> batchToProcess = null;
+        synchronized (this) {
+            if (batch.isEmpty()) {
+                return;
+            }
+            batchToProcess = batch;
+            batch = new ArrayList<>(batchSize);
+        }
+        // execute the processor out of the synchronized block
+        processor.accept(batchToProcess);
+    }
+
+    public void add(T t) {
+        List<T> batchToProcess = null;
+        synchronized (this) {
+            batch.add(t);
+            if (batch.size() >= batchSize || flushInterval <= 0) {
+                batchToProcess = batch;
+                batch = new ArrayList<>(batchSize);
+            }
+        }
+
+        // execute the processor out of the synchronized block
+        if (batchToProcess != null) {
+            processor.accept(batchToProcess);
+        }
+    }
+}
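The new `OrderedAsyncBatchExecutor` below is the heart of the change: records are hashed into a fixed number of buckets, each bucket queues its batches, and at most one batch per bucket is in flight at a time. The processor receives a `CompletableFuture` handle; completing it releases the next pending batch of the same bucket. A usage sketch mirroring the constructor in the file below (the values and the println processor are ours):

```java
import ai.langstream.api.util.OrderedAsyncBatchExecutor;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;

public class OrderedExecutorUsageSketch {
    public static void main(String[] args) {
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        OrderedAsyncBatchExecutor<String> executor =
                new OrderedAsyncBatchExecutor<>(
                        3, // batch size
                        (batch, done) ->
                                // asynchronous processor: completing the handle
                                // unblocks the next batch of the same bucket
                                CompletableFuture.runAsync(() -> System.out.println(batch))
                                        .whenComplete((r, e) -> done.complete(null)),
                        1000, // flush a partial batch after 1 s of idle time
                        4, // four buckets => up to four batches in flight
                        String::hashCode, // equal hashes => same bucket => ordered
                        scheduler);
        executor.start();
        executor.add("a");
        executor.add("b");
        executor.stop(); // flushes whatever is still buffered
        scheduler.shutdown();
    }
}
```

Note that `stop()` only flushes buffered records; completions that are already in flight finish asynchronously afterwards.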
diff --git a/langstream-api/src/main/java/ai/langstream/api/util/OrderedAsyncBatchExecutor.java b/langstream-api/src/main/java/ai/langstream/api/util/OrderedAsyncBatchExecutor.java
new file mode 100644
index 000000000..ba476990c
--- /dev/null
+++ b/langstream-api/src/main/java/ai/langstream/api/util/OrderedAsyncBatchExecutor.java
@@ -0,0 +1,174 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.langstream.api.util;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Queue;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiConsumer;
+import java.util.function.Function;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * Aggregate records in batches, depending on a batch size and a maximum idle time.
+ *
+ * @param <T> the type of the records
+ */
+@Slf4j
+public class OrderedAsyncBatchExecutor<T> {
+    private final int batchSize;
+    private final Bucket[] buckets;
+    private final int numBuckets;
+    private final long flushInterval;
+    private final ScheduledExecutorService scheduledExecutorService;
+
+    private ScheduledFuture<?> scheduledFuture;
+
+    private final BiConsumer<List<T>, CompletableFuture<?>> processor;
+
+    private final Function<T, Integer> hashFunction;
+
+    public OrderedAsyncBatchExecutor(
+            int batchSize,
+            BiConsumer<List<T>, CompletableFuture<?>> processor,
+            long maxIdleTime,
+            int numBuckets,
+            Function<T, Integer> hashFunction,
+            ScheduledExecutorService scheduledExecutorService) {
+        this.numBuckets = numBuckets;
+        this.hashFunction = hashFunction;
+        Object[] buckets = new Object[numBuckets];
+        for (int i = 0; i < numBuckets; i++) {
+            buckets[i] = new Bucket();
+        }
+        // try to avoid the "generic array creation" compile error
+        this.buckets = Arrays.copyOf(buckets, numBuckets, Bucket[].class);
+        this.batchSize = batchSize;
+        this.processor = processor;
+        this.flushInterval = maxIdleTime;
+        this.scheduledExecutorService = scheduledExecutorService;
+    }
+
+    public void start() {
+        if (flushInterval > 0) {
+            scheduledFuture =
+                    scheduledExecutorService.scheduleWithFixedDelay(
+                            this::flush, flushInterval, flushInterval, TimeUnit.MILLISECONDS);
+        }
+    }
+
+    public void stop() {
+        if (scheduledFuture != null) {
+            scheduledFuture.cancel(false);
+        }
+        flush();
+    }
+
+    private void flush() {
+        for (Bucket bucket : buckets) {
+            bucket.flush();
+        }
+    }
+
+    private Bucket bucket(int hash) {
+        return buckets[Math.abs(hash % numBuckets)];
+    }
+
+    public void add(T t) {
+        int hash = hashFunction.apply(t);
+        Bucket bucket = bucket(hash);
+        bucket.add(t);
+    }
+
+    private class Bucket {
+        private final Queue<List<T>> pendingBatches = new ArrayDeque<>();
+        private final List<T> currentBatch = new ArrayList<>();
+
+        private final AtomicReference<UUID> processing = new AtomicReference<>();
+
+        synchronized void add(T t) {
+            currentBatch.add(t);
+            if (currentBatch.size() >= batchSize || flushInterval <= 0) {
+                scheduleCurrentBatchExecution();
+            }
+        }
+
+        private synchronized void scheduleCurrentBatchExecution() {
+            if (currentBatch.isEmpty()) {
+                return;
+            }
+            List<T> batchToProcess = new ArrayList<>(currentBatch);
+            currentBatch.clear();
+            addToPendingBatches(batchToProcess);
+        }
+
+        private synchronized void processNextBatch() {
+            if (pendingBatches.isEmpty()) {
+                return;
+            }
+            List<T> nextBatch = pendingBatches.poll();
+            executeBatch(nextBatch);
+        }
+
+        private synchronized void addToPendingBatches(List<T> batchToProcess) {
+            if (processing.get() == null) {
+                executeBatch(batchToProcess);
+            } else {
+                pendingBatches.add(batchToProcess);
+            }
+        }
+
+        private void executeBatch(List<T> batchToProcess) {
+            UUID batchId = UUID.randomUUID();
+            CompletableFuture<Void> currentBatchHandle = new CompletableFuture<>();
+            currentBatchHandle.whenComplete(
+                    (result, error) -> {
+                        boolean check = processing.compareAndSet(batchId, null);
+                        if (!check) {
+                            log.error(
+                                    "Something went wrong, batch {} was not processed",
+                                    processing.get());
+                        } else {
+                            if (log.isDebugEnabled()) {
+                                log.debug("Batch {} completed", batchId);
+                            }
+                            processNextBatch();
+                        }
+                    });
+            boolean check = processing.compareAndSet(null, batchId);
+            if (!check) {
+                throw new IllegalStateException(
+                        "Something went wrong, the processor is still processing");
+            }
+            if (log.isDebugEnabled()) {
+                log.debug("Batch {} in bucket {} started for {}", batchId, this, batchToProcess);
+            }
+            processor.accept(batchToProcess, currentBatchHandle);
+        }
+
+        private synchronized void flush() {
+            scheduleCurrentBatchExecution();
+        }
+    }
+}
diff --git a/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/util/TransformFunctionUtilTest.java b/langstream-api/src/test/java/ai/langstream/api/BatchExecutorTest.java
similarity index 90%
rename from langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/util/TransformFunctionUtilTest.java
rename to langstream-api/src/test/java/ai/langstream/api/BatchExecutorTest.java
index 0b4eba809..35ead2b39 100644
--- a/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/util/TransformFunctionUtilTest.java
+++ b/langstream-api/src/test/java/ai/langstream/api/BatchExecutorTest.java
@@ -13,10 +13,11 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package com.datastax.oss.streaming.ai.util;
+package ai.langstream.api;
 
 import static org.junit.jupiter.api.Assertions.*;
 
+import ai.langstream.api.util.BatchExecutor;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.CopyOnWriteArrayList;
@@ -26,7 +27,7 @@
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.MethodSource;
 
-class TransformFunctionUtilTest {
+class BatchExecutorTest {
 
     static Object[][] batchSizes() {
         return new Object[][] {
@@ -50,8 +51,8 @@ void executeInBatchesTestWithFlushInterval(int numRecords, int batchSize) {
             records.add("text " + i);
         }
         List<String> result = new CopyOnWriteArrayList<>();
-        TransformFunctionUtil.BatchExecutor<String> executor =
-                new TransformFunctionUtil.BatchExecutor<>(
+        BatchExecutor<String> executor =
+                new BatchExecutor<>(
                         batchSize,
                         (batch) -> {
                             result.addAll(batch);
@@ -77,8 +78,8 @@ void executeInBatchesTestWithNoFlushInterval(int numRecords, int batchSize) {
             records.add("text " + i);
         }
         List<String> result = new CopyOnWriteArrayList<>();
-        TransformFunctionUtil.BatchExecutor<String> executor =
-                new TransformFunctionUtil.BatchExecutor<>(
+        BatchExecutor<String> executor =
+                new BatchExecutor<>(
                         batchSize,
                         (batch) -> {
                             result.addAll(batch);
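The rename above keeps the old `BatchExecutor` tests intact; the new executor gets its own suite below. Beyond the any-order smoke tests, the interesting invariant is the one-in-flight-batch-per-bucket gate. A standalone illustration of that property (our own sketch, not part of the diff), using a constant hash so a single bucket serializes everything even when completions arrive from another thread:

```java
import ai.langstream.api.util.OrderedAsyncBatchExecutor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class InFlightSketch {
    public static void main(String[] args) throws Exception {
        AtomicInteger inFlight = new AtomicInteger();
        ScheduledExecutorService pool = Executors.newScheduledThreadPool(2);
        OrderedAsyncBatchExecutor<Integer> executor =
                new OrderedAsyncBatchExecutor<>(
                        2,
                        (batch, done) -> {
                            // the executor must never start a batch while the
                            // previous one of the same bucket is still running
                            if (inFlight.incrementAndGet() > 1) {
                                throw new IllegalStateException("two batches overlapped");
                            }
                            // complete later, on another thread
                            pool.schedule(
                                    () -> {
                                        inFlight.decrementAndGet();
                                        done.complete(null);
                                    },
                                    10,
                                    TimeUnit.MILLISECONDS);
                        },
                        0, // no idle flush: every add() dispatches immediately
                        1, // a single bucket
                        i -> 0, // constant hash: all records share that bucket
                        null); // scheduler not needed when flushInterval <= 0
        executor.start();
        for (int i = 0; i < 100; i++) {
            executor.add(i);
        }
        TimeUnit.SECONDS.sleep(2); // crude wait for the async completions
        executor.stop();
        pool.shutdown();
    }
}
```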
diff --git a/langstream-api/src/test/java/ai/langstream/api/OrderedAsyncBatchExecutorTest.java b/langstream-api/src/test/java/ai/langstream/api/OrderedAsyncBatchExecutorTest.java
new file mode 100644
index 000000000..48ea3773f
--- /dev/null
+++ b/langstream-api/src/test/java/ai/langstream/api/OrderedAsyncBatchExecutorTest.java
@@ -0,0 +1,303 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package ai.langstream.api;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import ai.langstream.api.util.OrderedAsyncBatchExecutor;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+import lombok.extern.slf4j.Slf4j;
+import org.awaitility.Awaitility;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+@Slf4j
+class OrderedAsyncBatchExecutorTest {
+
+    static Object[][] batchSizes() {
+        return new Object[][] {
+            new Object[] {0, 1},
+            new Object[] {1, 1},
+            new Object[] {1, 2},
+            new Object[] {2, 1},
+            new Object[] {2, 2},
+            new Object[] {3, 5},
+            new Object[] {5, 3}
+        };
+    }
+
+    @ParameterizedTest
+    @MethodSource("batchSizes")
+    void executeInBatchesTestWithFlushInterval(int numRecords, int batchSize) {
+        long flushInterval = 1000;
+        int numBuckets = 4;
+        Function<String, Integer> hashFunction = String::hashCode;
+        ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor();
+        List<String> records = new ArrayList<>();
+        for (int i = 0; i < numRecords; i++) {
+            records.add("text " + i);
+        }
+
+        List<String> result = new CopyOnWriteArrayList<>();
+        OrderedAsyncBatchExecutor<String> executor =
+                new OrderedAsyncBatchExecutor<>(
+                        batchSize,
+                        (batch, future) -> {
+                            result.addAll(batch);
+                            future.complete(null);
+                        },
+                        flushInterval,
+                        numBuckets,
+                        hashFunction,
+                        executorService);
+        executor.start();
+        records.forEach(executor::add);
+        // this may happen after some time
+        Awaitility.await().untilAsserted(() -> assertEqualsInAnyOrder(records, result));
+        executor.stop();
+        executorService.shutdown();
+    }
+
+    private static void assertEqualsInAnyOrder(List<String> a, List<String> b) {
+        if (a.size() != b.size()) {
+            throw new AssertionError("Lists have different sizes");
+        }
+        assertEquals(new HashSet<>(a), new HashSet<>(b));
+    }
+
+    @ParameterizedTest
+    @MethodSource("batchSizes")
+    void executeInBatchesTestWithNoFlushInterval(int numRecords, int batchSize) {
+        long flushInterval = 0;
+        int numBuckets = 4;
+        Function<String, Integer> hashFunction = String::hashCode;
+        // this is not needed
+        ScheduledExecutorService executorService = null;
+        List<String> records = new ArrayList<>();
+        for (int i = 0; i < numRecords; i++) {
+            records.add("text " + i);
+        }
+        List<String> result = new CopyOnWriteArrayList<>();
+        OrderedAsyncBatchExecutor<String> executor =
+                new OrderedAsyncBatchExecutor<>(
+                        batchSize,
+                        (batch, future) -> {
+                            // as flush interval is zero, this is called immediately
+                            log.info("processed batch {}", batch);
+                            assertEquals(1, batch.size());
+                            result.addAll(batch);
+                            future.complete(null);
+                        },
+                        flushInterval,
+                        numBuckets,
+                        hashFunction,
+                        executorService);
+        executor.start();
+        records.forEach(executor::add);
+        // this may happen after some time
+        Awaitility.await().untilAsserted(() -> assertEquals(records, result));
+        executor.stop();
+    }
+
+    static Object[][] batchSizesAndDelays() {
+        return new Object[][] {
+            new Object[] {0, 1, 0},
+            new Object[] {1, 1, 0},
+            new Object[] {1, 2, 0},
+            new Object[] {2, 1, 0},
+            new Object[] {2, 2, 0},
+            new Object[] {3, 5, 0},
+            new Object[] {5, 3, 0},
+            new Object[] {37, 5, 0},
+            new Object[] {50, 3, 0},
+            new Object[] {0, 1, 200},
+            new Object[] {1, 1, 200},
+            new Object[] {1, 2, 200},
+            new Object[] {2, 1, 200},
+            new Object[] {2, 2, 200},
+            new Object[] {3, 5, 200},
+            new Object[] {5, 3, 200},
+            new Object[] {37, 5, 200},
+            new Object[] {50, 3, 200}
+        };
+    }
+
+    @ParameterizedTest
+    @MethodSource("batchSizesAndDelays")
+    void executeInBatchesTestWithKeyOrdering(int numRecords, int batchSize, long delay) {
+
+        record KeyValue(int key, String value) {}
+
+        long flushInterval = 1000;
+        int numBuckets = 4;
+        Function<KeyValue, Integer> hashFunction = KeyValue::key;
+        ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor();
+        ScheduledExecutorService completionsExecutorService = Executors.newScheduledThreadPool(4);
+
+        List<KeyValue> records = new ArrayList<>();
+        Map<Integer, List<KeyValue>> recordsByKey = new HashMap<>();
+
+        for (int i = 0; i < numRecords; i++) {
+            int key = i % 7;
+            KeyValue record = new KeyValue(key, "text " + i);
+            records.add(record);
+            List<KeyValue> valuesForKey =
+                    recordsByKey.computeIfAbsent(record.key, l -> new ArrayList<>());
+            valuesForKey.add(record);
+        }
+
+        Random random = new Random();
+
+        Map<Integer, List<KeyValue>> results = new ConcurrentHashMap<>();
+        OrderedAsyncBatchExecutor<KeyValue> executor =
+                new OrderedAsyncBatchExecutor<>(
+                        batchSize,
+                        (batch, future) -> {
+                            for (KeyValue record : batch) {
+                                List<KeyValue> resultsForKey =
+                                        results.computeIfAbsent(
+                                                record.key, l -> new CopyOnWriteArrayList<>());
+                                resultsForKey.add(record);
+                            }
+                            if (delay == 0) {
+                                // execute the completions in the same thread
+                                future.complete(null);
+                            } else {
+                                // delay the completion of the future, and do it in a separate
+                                // thread
+                                long delayMillis = random.nextInt((int) delay);
+                                completionsExecutorService.schedule(
+                                        () -> future.complete(null),
+                                        delayMillis,
+                                        TimeUnit.MILLISECONDS);
+                            }
+                        },
+                        flushInterval,
+                        numBuckets,
+                        hashFunction,
+                        executorService);
+        executor.start();
+        records.forEach(executor::add);
+        // this may happen after some time
+        Awaitility.await()
+                .untilAsserted(
+                        () -> {
+                            recordsByKey.forEach(
+                                    (key, values) -> {
+                                        List<KeyValue> resultsForKey = results.get(key);
+                                        log.info(
+                                                "key: {}, values: {}, results: {}",
+                                                key,
+                                                values,
+                                                resultsForKey);
+                                        // the order must be preserved
+                                        assertEquals(values, resultsForKey);
+                                    });
+                        });
+        executor.stop();
+        executorService.shutdown();
+        completionsExecutorService.shutdown();
+    }
+
+    @ParameterizedTest
+    @MethodSource("batchSizesAndDelays")
+    void executeInBatchesTestWithKeyOrderingWithoutFlush(
+            int numRecords, int batchSize, long delay) {
+
+        record KeyValue(int key, String value) {}
+
+        long flushInterval = 0;
+        int numBuckets = 4;
+        Function<KeyValue, Integer> hashFunction = KeyValue::key;
+        ScheduledExecutorService executorService = null;
+        ScheduledExecutorService completionsExecutorService = Executors.newScheduledThreadPool(4);
+
+        List<KeyValue> records = new ArrayList<>();
+        Map<Integer, List<KeyValue>> recordsByKey = new HashMap<>();
+
+        for (int i = 0; i < numRecords; i++) {
+            int key = i % 7;
+            KeyValue record = new KeyValue(key, "text " + i);
+            records.add(record);
+            List<KeyValue> valuesForKey =
+                    recordsByKey.computeIfAbsent(record.key, l -> new ArrayList<>());
+            valuesForKey.add(record);
+        }
+
+        Random random = new Random();
+
+        Map<Integer, List<KeyValue>> results = new ConcurrentHashMap<>();
+        OrderedAsyncBatchExecutor<KeyValue> executor =
+                new OrderedAsyncBatchExecutor<>(
+                        batchSize,
+                        (batch, future) -> {
+                            for (KeyValue record : batch) {
+                                List<KeyValue> resultsForKey =
+                                        results.computeIfAbsent(
+                                                record.key, l -> new CopyOnWriteArrayList<>());
+                                resultsForKey.add(record);
+                            }
+                            if (delay == 0) {
+                                // execute the completions in the same thread
+                                future.complete(null);
+                            } else {
+                                // delay the completion of the future, and do it in a separate
+                                // thread
+                                long delayMillis = random.nextInt((int) delay);
+                                completionsExecutorService.schedule(
+                                        () -> future.complete(null),
+                                        delayMillis,
+                                        TimeUnit.MILLISECONDS);
+                            }
+                        },
+                        flushInterval,
+                        numBuckets,
+                        hashFunction,
+                        executorService);
+        executor.start();
+        records.forEach(executor::add);
+        // this flushes the pending records
+        executor.stop();
+        // this may happen after some time
+        Awaitility.await()
+                .untilAsserted(
+                        () -> {
+                            recordsByKey.forEach(
+                                    (key, values) -> {
+                                        List<KeyValue> resultsForKey = results.get(key);
+                                        log.info(
+                                                "key: {}, values: {}, results: {}",
+                                                key,
+                                                values,
+                                                resultsForKey);
+                                        // the order must be preserved
+                                        assertEquals(values, resultsForKey);
+                                    });
+                        });
+        completionsExecutorService.shutdown();
+    }
+}
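The key-ordering tests above pick keys `i % 7` against `numBuckets = 4`, so several keys share a bucket while the per-key order assertion still holds. The bucket arithmetic, spelled out (mirrors `bucket()` in `OrderedAsyncBatchExecutor` and the `KeyValue::key` hash function; the loop itself is ours):

```java
public class KeyDistributionSketch {
    public static void main(String[] args) {
        int numBuckets = 4;
        // hashFunction = KeyValue::key, so the hash is the key itself
        for (int key = 0; key < 7; key++) {
            System.out.println("key " + key + " -> bucket " + Math.abs(key % numBuckets));
        }
        // keys 0..6 land in buckets 0,1,2,3,0,1,2: buckets 0..2 each carry
        // two interleaved keys, which is exactly what the ordering
        // assertion must survive
    }
}
```

The remaining hunks below wire the new `concurrency` option through the agent provider (as an optional field next to `batch-size`) and adapt the runtime integration tests.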
diff --git a/langstream-core/src/main/java/ai/langstream/impl/agents/ai/GenAIToolKitFunctionAgentProvider.java b/langstream-core/src/main/java/ai/langstream/impl/agents/ai/GenAIToolKitFunctionAgentProvider.java
index dcce509d5..17955ea5f 100644
--- a/langstream-core/src/main/java/ai/langstream/impl/agents/ai/GenAIToolKitFunctionAgentProvider.java
+++ b/langstream-core/src/main/java/ai/langstream/impl/agents/ai/GenAIToolKitFunctionAgentProvider.java
@@ -169,6 +169,12 @@ public void generateSteps(
                         originalConfiguration,
                         "batch-size",
                         null);
+                optionalField(
+                        step,
+                        agentConfiguration,
+                        originalConfiguration,
+                        "concurrency",
+                        null);
                 optionalField(
                         step,
                         agentConfiguration,
diff --git a/langstream-runtime/langstream-runtime-impl/pom.xml b/langstream-runtime/langstream-runtime-impl/pom.xml
index 0198844ce..b8dce7ebf 100644
--- a/langstream-runtime/langstream-runtime-impl/pom.xml
+++ b/langstream-runtime/langstream-runtime-impl/pom.xml
@@ -416,7 +416,7 @@
           <execution>
            <id>copy</id>
-            <phase>package</phase>
+            <phase>generate-test-sources</phase>
             <goals>
               <goal>copy</goal>
             </goals>
           </execution>
diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/AbstractApplicationRunner.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/AbstractApplicationRunner.java
index 27227b09f..0a6131f57 100644
--- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/AbstractApplicationRunner.java
+++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/AbstractApplicationRunner.java
@@ -40,9 +40,9 @@
 import java.nio.file.Path;
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
@@ -210,9 +210,15 @@ protected void sendMessage(String topic, Object content, KafkaProducer producer)
     protected void sendMessage(
             String topic, Object content, List<Header> headers, KafkaProducer producer)
             throws Exception {
+        sendMessage(topic, "key", content, headers, producer);
+    }
+
+    protected void sendMessage(
+            String topic, Object key, Object content, List<Header> headers, KafkaProducer producer)
+            throws Exception {
         producer.send(
                         new ProducerRecord<>(
-                                topic, null, System.currentTimeMillis(), "key", content, headers))
+                                topic, null, System.currentTimeMillis(), key, content, headers))
                 .get();
         producer.flush();
     }
@@ -253,7 +259,7 @@ protected List<ConsumerRecord> waitForMessages(KafkaConsumer consumer, List<Object> e
     }
 
     protected List<ConsumerRecord> waitForMessagesInAnyOrder(
-            KafkaConsumer consumer, Set<String> expected) {
+            KafkaConsumer consumer, Collection<String> expected) {
         List<ConsumerRecord> result = new ArrayList<>();
         List<Object> received = new ArrayList<>();
@@ -282,6 +288,17 @@ protected List<ConsumerRecord> waitForMessagesInAnyOrder(
                                             + " not found in "
                                             + received);
                         }
+
+                        for (Object receivedValue : received) {
+                            // this doesn't work for byte[]
+                            assertFalse(receivedValue instanceof byte[]);
+                            assertTrue(
+                                    expected.contains(receivedValue),
+                                    "Received value "
+                                            + receivedValue
+                                            + " not found in "
+                                            + expected);
+                        }
                     });
 
         return result;
diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ComputeEmbeddingsIT.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ComputeEmbeddingsIT.java
index 5359e1016..abdc41188 100644
--- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ComputeEmbeddingsIT.java
+++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ComputeEmbeddingsIT.java
@@ -19,9 +19,7 @@
 import static com.github.tomakehurst.wiremock.client.WireMock.okJson;
 import static com.github.tomakehurst.wiremock.client.WireMock.post;
 import static com.github.tomakehurst.wiremock.client.WireMock.stubFor;
-import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.junit.jupiter.api.Assertions.fail;
 
 import ai.langstream.AbstractApplicationRunner;
 import ai.langstream.api.model.Application;
@@ -37,17 +35,16 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
-import java.util.function.Consumer;
 import java.util.stream.Stream;
 import lombok.AllArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.clients.producer.KafkaProducer;
 import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.MethodSource;
+import org.junit.jupiter.params.provider.ValueSource;
 
 @Slf4j
 @WireMockTest
@@ -222,6 +219,7 @@ public void testComputeEmbeddings(EmbeddingsConfig config) throws Exception {
                                     model: "%s"
                                     embeddings-field: "value.embeddings"
                                     text: "something to embed"
+                                    concurrency: 1
                                     flush-interval: 0
                         """
                         .formatted(
@@ -266,8 +264,9 @@ tenant, appId, application, buildInstanceYaml(), expectedAgents)) {
         }
     }
 
-    @Test
-    public void testComputeBatchEmbeddings() throws Exception {
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    public void testComputeBatchEmbeddings(boolean sameKey) throws Exception {
         wireMockRuntimeInfo
                 .getWireMock()
                 .allStubMappings()
@@ -277,6 +276,9 @@ public void testComputeBatchEmbeddings() throws Exception {
                             log.info("Removing stub {}", stubMapping);
                             wireMockRuntimeInfo.getWireMock().removeStubMapping(stubMapping);
                         });
+        String embeddingFirst = "[1.0,5.4,8.7]";
+        String embeddingSecond = "[2.0,5.4,8.7]";
+        String embeddingThird = "[3.0,5.4,8.7]";
         stubFor(
                 post("/openai/deployments/text-embedding-ada-002/embeddings?api-version=2023-08-01-preview")
                         .willReturn(
                                 okJson(
                                         """
@@ -285,17 +287,17 @@ public void testComputeBatchEmbeddings() throws Exception {
                                           {
                                             "data": [
                                               {
-                                                "embedding": [1.0, 5.4, 8.7],
+                                                "embedding": %s,
                                                 "index": 0,
                                                 "object": "embedding"
                                               },
                                               {
-                                                "embedding": [2.0, 5.4, 8.7],
+                                                "embedding": %s,
                                                 "index": 0,
                                                 "object": "embedding"
                                               },
                                               {
-                                                "embedding": [3.0, 5.4, 8.7],
+                                                "embedding": %s,
                                                 "index": 0,
                                                 "object": "embedding"
                                               }
@@ -307,7 +309,11 @@ public void testComputeBatchEmbeddings() throws Exception {
                                             "total_tokens": 5
                                           }
                                         }
-                                        """)));
+                                        """
+                                                .formatted(
+                                                        embeddingFirst,
+                                                        embeddingSecond,
+                                                        embeddingThird))));
 
         // wait for WireMock to be ready
         Thread.sleep(1000);
@@ -356,6 +362,7 @@ public void testComputeBatchEmbeddings() throws Exception {
                                     embeddings-field: "value.embeddings"
                                     text: "something to embed"
                                     batch-size: 3
+                                    concurrency: 4
                                     flush-interval: 10000
                         """
                         .formatted(
@@ -381,39 +388,56 @@ tenant, appId, application, buildInstanceYaml(), expectedAgents)) {
                 KafkaConsumer consumer = createConsumer(outputTopic)) {
 
             // produce 10 messages to the input-topic
-            List<Consumer<Object>> expected = new ArrayList<>();
+            List<String> expected = new ArrayList<>();
             for (int i = 0; i < 9; i++) {
                 String name = "name_" + i;
+                String key = sameKey ? "key" : "key_" + (i % 3);
                 sendMessage(
                         inputTopic,
+                        key,
                         "{\"name\": \" " + name + "\", \"description\": \"some description\"}",
+                        List.of(),
                         producer);
-                int _i = i;
-                expected.add(
-                        text -> {
-                            String embeddings = "";
-                            if (_i % 3 == 0) {
-                                embeddings = "[1.0,5.4,8.7]";
-                            } else if (_i % 3 == 1) {
-                                embeddings = "[2.0,5.4,8.7]";
-                            } else if (_i % 3 == 2) {
-                                embeddings = "[3.0,5.4,8.7]";
-                            } else {
-                                fail();
-                            }
-                            assertEquals(
-                                    "{\"name\":\" "
-                                            + name
-                                            + "\",\"description\":\"some description\",\"embeddings\":"
-                                            + embeddings
-                                            + "}",
-                                    text.toString());
-                        });
+
+                String embeddings;
+                if (sameKey) {
+                    if (i % 3 == 0) {
+                        embeddings = embeddingFirst;
+                    } else if (i % 3 == 1) {
+                        embeddings = embeddingSecond;
+                    } else {
+                        embeddings = embeddingThird;
+                    }
+                } else {
+                    // this may look weird, but given the key distribution, we build 3 batches
+                    // that contain 3 messages each
+                    // the first 3 messages become the head of each batch, the next 3 messages
+                    // are the second element of each batch, and so on
+                    embeddings =
+                            switch (i) {
+                                case 0, 1, 2 -> embeddingFirst;
+                                case 3, 4, 5 -> embeddingSecond;
+                                case 6, 7, 8 -> embeddingThird;
+                                default -> throw new IllegalStateException();
+                            };
+                }
+                String expectedContent =
+                        "{\"name\":\" "
+                                + name
+                                + "\",\"description\":\"some description\",\"embeddings\":"
+                                + embeddings
+                                + "}";
+                expected.add(expectedContent);
             }
             executeAgentRunners(applicationRuntime);
-            waitForMessages(consumer, expected);
+            if (sameKey) {
+                // all the messages have the same key, so they must be processed in order
+                waitForMessages(consumer, expected);
+            } else {
+                waitForMessagesInAnyOrder(consumer, expected);
+            }
         }
     }
}
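Worth spelling out why the `sameKey = false` branch expects those embeddings: messages are keyed `key_0`/`key_1`/`key_2`, and with `batch-size: 3` each key accumulates one batch, {0, 3, 6}, {1, 4, 7} or {2, 5, 8}, while the stub answers every batch with the same three embeddings in order (the test relies on the three keys landing in distinct buckets). A small restatement of that mapping (our own sketch of the test's comment):

```java
public class BatchCompositionSketch {
    public static void main(String[] args) {
        // key_0 batches [0, 3, 6], key_1 batches [1, 4, 7], key_2 batches [2, 5, 8];
        // inside each batch the stub returns embeddingFirst/Second/Third in order,
        // so the embedding of message i depends only on i / 3
        for (int i = 0; i < 9; i++) {
            String key = "key_" + (i % 3);
            int positionInBatch = i / 3;
            System.out.println(
                    "message " + i + " (key " + key + ") -> embedding #" + (positionInBatch + 1));
        }
    }
}
```

Since batches of different keys complete independently, the output order across keys is not deterministic, hence `waitForMessagesInAnyOrder`; with a single key everything funnels through one bucket and strict order is asserted via `waitForMessages`.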