From 5927f8c7e404786fa55d05f86721bd51e0517772 Mon Sep 17 00:00:00 2001 From: Chris Bartholomew Date: Thu, 13 Jun 2024 11:16:46 -0400 Subject: [PATCH] Openai embedding dimensions and newer Azure SDK (#62) --- .../compute-openai-embeddings/pipeline.yaml | 3 +- .../openai-completions/pipeline.yaml | 4 +- .../impl/OpenAICompletionService.java | 39 +++++-- .../embeddings/OpenAIEmbeddingsService.java | 20 +++- .../config/ComputeAIEmbeddingsConfig.java | 3 + .../ai/services/OpenAIServiceProvider.java | 4 +- .../services/impl/OpenAIProviderTest.java | 4 +- .../ai/agents/services/impl/OpenAITest.java | 107 ++++++++++++++++++ .../functions/transforms/GenAITest.java | 15 ++- .../streaming/ai/ChatCompletionsStepTest.java | 82 +++++++++++--- .../ComputeAIEmbeddingsConfiguration.java | 9 ++ ...GenAIToolKitFunctionAgentProviderTest.java | 5 + .../langstream/kafka/ChatCompletionsIT.java | 2 +- .../langstream/kafka/ComputeEmbeddingsIT.java | 99 ++++++++++++---- .../kafka/FlareControllerAgentRunnerIT.java | 36 +++++- .../langstream/kafka/TextCompletionsIT.java | 2 +- pom.xml | 4 +- 17 files changed, 374 insertions(+), 64 deletions(-) create mode 100644 langstream-agents/langstream-ai-agents/src/test/java/ai/langstream/ai/agents/services/impl/OpenAITest.java diff --git a/examples/applications/compute-openai-embeddings/pipeline.yaml b/examples/applications/compute-openai-embeddings/pipeline.yaml index e0a222050..1a894b7d5 100644 --- a/examples/applications/compute-openai-embeddings/pipeline.yaml +++ b/examples/applications/compute-openai-embeddings/pipeline.yaml @@ -35,9 +35,10 @@ pipeline: type: "compute-ai-embeddings" output: "output-topic" configuration: - model: "${secrets.open-ai.embeddings-model}" # This needs to match the name of the model deployment, not the base model + model: "text-embedding-3-large" # This needs to match the name of the model deployment, not the base model embeddings-field: "value.embeddings" text: "{{ value.text }}" + dimensions: 256 batch-size: 10 # this is in milliseconds. It is important to take this value into consideration when using this agent in the chat response pipeline # in fact this value impacts the latency of the response diff --git a/examples/applications/openai-completions/pipeline.yaml b/examples/applications/openai-completions/pipeline.yaml index bc4da214b..d624aab67 100644 --- a/examples/applications/openai-completions/pipeline.yaml +++ b/examples/applications/openai-completions/pipeline.yaml @@ -48,5 +48,7 @@ pipeline: # eventually we want to send bigger messages to reduce the overhead of each message on the topic min-chunks-per-message: 10 messages: + - role: system + content: "You are helpful assistant." - role: user - content: "You are a helpful assistant. Below you can find a question from the user. Please try to help them the best way you can.\n\n{{ value.question}}" + content: "{{ value.question}}" diff --git a/langstream-agents/langstream-ai-agents/src/main/java/ai/langstream/ai/agents/services/impl/OpenAICompletionService.java b/langstream-agents/langstream-ai-agents/src/main/java/ai/langstream/ai/agents/services/impl/OpenAICompletionService.java index 121939881..e57401b3f 100644 --- a/langstream-agents/langstream-ai-agents/src/main/java/ai/langstream/ai/agents/services/impl/OpenAICompletionService.java +++ b/langstream-agents/langstream-ai-agents/src/main/java/ai/langstream/ai/agents/services/impl/OpenAICompletionService.java @@ -22,7 +22,10 @@ import ai.langstream.api.runner.code.MetricsReporter; import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.models.ChatCompletionsOptions; -import com.azure.ai.openai.models.ChatRole; +import com.azure.ai.openai.models.ChatRequestAssistantMessage; +import com.azure.ai.openai.models.ChatRequestMessage; +import com.azure.ai.openai.models.ChatRequestSystemMessage; +import com.azure.ai.openai.models.ChatRequestUserMessage; import com.azure.ai.openai.models.CompletionsFinishReason; import com.azure.ai.openai.models.CompletionsLogProbabilityModel; import com.azure.ai.openai.models.CompletionsOptions; @@ -124,16 +127,29 @@ public CompletableFuture getChatCompletions( StreamingChunksConsumer streamingChunksConsumer, Map options) { int minChunksPerMessage = getInteger("min-chunks-per-message", 20, options); + + List chatMessages = + messages.stream() + .map( + message -> { + switch (message.getRole()) { + case "system": + return new ChatRequestSystemMessage( + message.getContent()); + case "user": + return new ChatRequestUserMessage(message.getContent()); + case "assistant": + return new ChatRequestAssistantMessage( + message.getContent()); + default: + throw new IllegalArgumentException( + "Unknown chat role: " + message.getRole()); + } + }) + .collect(Collectors.toList()); + ChatCompletionsOptions chatCompletionsOptions = - new ChatCompletionsOptions( - messages.stream() - .map( - message -> - new com.azure.ai.openai.models.ChatMessage( - ChatRole.fromString( - message.getRole()), - message.getContent())) - .collect(Collectors.toList())) + new ChatCompletionsOptions(chatMessages) .setMaxTokens(getInteger("max-tokens", null, options)) .setTemperature(getDouble("temperature", null, options)) .setTopP(getDouble("top-p", null, options)) @@ -143,6 +159,7 @@ public CompletableFuture getChatCompletions( .setStop((List) options.get("stop")) .setPresencePenalty(getDouble("presence-penalty", null, options)) .setFrequencyPenalty(getDouble("frequency-penalty", null, options)); + ChatCompletions result = new ChatCompletions(); chatNumCalls.count(1); // this is the default behavior, as it is async @@ -211,7 +228,7 @@ public CompletableFuture getChatCompletions( } private static ChatMessage convertMessage(com.azure.ai.openai.models.ChatChoice c) { - com.azure.ai.openai.models.ChatMessage message = c.getMessage(); + com.azure.ai.openai.models.ChatResponseMessage message = c.getMessage(); if (message == null) { message = c.getDelta(); } diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/embeddings/OpenAIEmbeddingsService.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/embeddings/OpenAIEmbeddingsService.java index ebaa57399..78efb001d 100644 --- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/embeddings/OpenAIEmbeddingsService.java +++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/embeddings/OpenAIEmbeddingsService.java @@ -30,6 +30,7 @@ public class OpenAIEmbeddingsService implements EmbeddingsService { private final OpenAIAsyncClient openAIClient; private final String model; + private final Integer dimensions; private final MetricsReporter.Counter totalTokens; private final MetricsReporter.Counter promptTokens; @@ -38,9 +39,13 @@ public class OpenAIEmbeddingsService implements EmbeddingsService { private final MetricsReporter.Counter numErrors; public OpenAIEmbeddingsService( - OpenAIAsyncClient openAIClient, String model, MetricsReporter metricsReporter) { + OpenAIAsyncClient openAIClient, + String model, + MetricsReporter metricsReporter, + Integer dimensions) { this.openAIClient = openAIClient; this.model = model; + this.dimensions = dimensions; this.totalTokens = metricsReporter.counter( "openai_embeddings_total_tokens", @@ -65,8 +70,13 @@ public OpenAIEmbeddingsService( public CompletableFuture>> computeEmbeddings(List texts) { try { EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(texts); + if (dimensions > 0) { + log.debug("Setting embedding dimensions to {}", dimensions); + embeddingsOptions.setDimensions(dimensions); + } numCalls.count(1); numTexts.count(texts.size()); + CompletableFuture>> result = openAIClient .getEmbeddings(model, embeddingsOptions) @@ -78,12 +88,12 @@ public CompletableFuture>> computeEmbeddings(List text promptTokens.count(usage.getPromptTokens()); return embeddings.getData().stream() .map(EmbeddingItem::getEmbedding) + .map(this::convertToDoubleList) .collect(Collectors.toList()); }); result.exceptionally( err -> { - // API call error numErrors.count(1); return null; }); @@ -94,4 +104,10 @@ public CompletableFuture>> computeEmbeddings(List text return CompletableFuture.failedFuture(err); } } + + private List convertToDoubleList(List floatList) { + // Log the length of the float list + log.debug("Float list length: {}", floatList.size()); + return floatList.stream().map(Float::doubleValue).collect(Collectors.toList()); + } } diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/config/ComputeAIEmbeddingsConfig.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/config/ComputeAIEmbeddingsConfig.java index 10e475b6a..ca51991a0 100644 --- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/config/ComputeAIEmbeddingsConfig.java +++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/model/config/ComputeAIEmbeddingsConfig.java @@ -55,4 +55,7 @@ public class ComputeAIEmbeddingsConfig extends StepConfig { @JsonProperty(value = "model-url") String modelUrl; + + @JsonProperty(value = "dimensions") + private int dimensions = 0; } diff --git a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/services/OpenAIServiceProvider.java b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/services/OpenAIServiceProvider.java index d6eefe249..d4dccb57b 100644 --- a/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/services/OpenAIServiceProvider.java +++ b/langstream-agents/langstream-ai-agents/src/main/java/com/datastax/oss/streaming/ai/services/OpenAIServiceProvider.java @@ -48,7 +48,9 @@ public CompletionsService getCompletionsService(Map additionalCo @Override public EmbeddingsService getEmbeddingsService(Map additionalConfiguration) { String model = (String) additionalConfiguration.get("model"); - return new OpenAIEmbeddingsService(client, model, metricsReporter); + // Get dimensions from the configuration + Integer dimensions = (Integer) additionalConfiguration.get("dimensions"); + return new OpenAIEmbeddingsService(client, model, metricsReporter, dimensions); } @Override diff --git a/langstream-agents/langstream-ai-agents/src/test/java/ai/langstream/ai/agents/services/impl/OpenAIProviderTest.java b/langstream-agents/langstream-ai-agents/src/test/java/ai/langstream/ai/agents/services/impl/OpenAIProviderTest.java index 24e4bba8f..972750a5b 100644 --- a/langstream-agents/langstream-ai-agents/src/test/java/ai/langstream/ai/agents/services/impl/OpenAIProviderTest.java +++ b/langstream-agents/langstream-ai-agents/src/test/java/ai/langstream/ai/agents/services/impl/OpenAIProviderTest.java @@ -46,7 +46,7 @@ class OpenAIProviderTest { void testStreamingChatCompletion(WireMockRuntimeInfo vmRuntimeInfo) throws Exception { resetWiremockStubs(vmRuntimeInfo); stubFor( - post("/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-08-01-preview") + post("/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-03-01-preview") .willReturn( okJson( """ @@ -127,7 +127,7 @@ public void consumeChunk( void testStreamingTextCompletion(WireMockRuntimeInfo vmRuntimeInfo) throws Exception { resetWiremockStubs(vmRuntimeInfo); stubFor( - post("/openai/deployments/gpt-35-turbo-instruct/completions?api-version=2023-08-01-preview") + post("/openai/deployments/gpt-35-turbo-instruct/completions?api-version=2024-03-01-preview") .withRequestBody( equalTo( "{\"prompt\":[\"Translate from English to Italian: \\\"I love cars\\\" with quotes\"],\"stream\":true}")) diff --git a/langstream-agents/langstream-ai-agents/src/test/java/ai/langstream/ai/agents/services/impl/OpenAITest.java b/langstream-agents/langstream-ai-agents/src/test/java/ai/langstream/ai/agents/services/impl/OpenAITest.java new file mode 100644 index 000000000..99bbfcf6d --- /dev/null +++ b/langstream-agents/langstream-ai-agents/src/test/java/ai/langstream/ai/agents/services/impl/OpenAITest.java @@ -0,0 +1,107 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package ai.langstream.ai.agents.services.impl; + +import com.azure.ai.openai.OpenAIAsyncClient; +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatRequestMessage; +import com.azure.ai.openai.models.ChatRequestUserMessage; +import com.azure.ai.openai.models.EmbeddingItem; +import com.azure.ai.openai.models.Embeddings; +import com.azure.ai.openai.models.EmbeddingsOptions; +import com.azure.ai.openai.models.EmbeddingsUsage; +import com.azure.core.credential.AzureKeyCredential; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +public class OpenAITest { + + private OpenAIAsyncClient openAIClient; + + @BeforeEach + void setup() { + openAIClient = + new OpenAIClientBuilder() + .credential(new AzureKeyCredential("YOUR_OPENAI_KEY")) + .buildAsyncClient(); + } + + @Disabled + @Test + void testRealChatCompletions() throws Exception { + List chatMessages = new ArrayList<>(); + chatMessages.add(new ChatRequestUserMessage("Name the US presidents of the 20th century")); + + ChatCompletionsOptions options = new ChatCompletionsOptions(chatMessages); + + CompletableFuture completableFuture = new CompletableFuture<>(); + + openAIClient + .getChatCompletionsStream("gpt-3.5-turbo", options) + .doOnNext( + chatCompletions -> { + String response = + chatCompletions.getChoices().get(0).getDelta().getContent(); + System.out.println("Response: " + response); + if (response == null) { + completableFuture.complete(null); + } + }) + .doOnError(completableFuture::completeExceptionally) + .subscribe(); + + completableFuture.join(); + } + + @Disabled + @Test + void testRealEmbeddings() throws Exception { + String azureOpenaiKey = ""; + String deploymentOrModelId = "text-embedding-ada-002"; + + OpenAIClient client = + new OpenAIClientBuilder() + .credential(new AzureKeyCredential(azureOpenaiKey)) + .buildClient(); + + EmbeddingsOptions embeddingsOptions = + new EmbeddingsOptions(Arrays.asList("Your text string goes here")); + + Embeddings embeddings = client.getEmbeddings(deploymentOrModelId, embeddingsOptions); + + for (EmbeddingItem item : embeddings.getData()) { + System.out.printf("Index: %d.%n", item.getPromptIndex()); + System.out.println( + "Embedding as base64 encoded string: " + item.getEmbeddingAsString()); + System.out.println("Embedding as list of floats: "); + for (Float embedding : item.getEmbedding()) { + System.out.printf("%f;", embedding); + } + } + + EmbeddingsUsage usage = embeddings.getUsage(); + System.out.printf( + "Usage: number of prompt token is %d and number of total tokens in request and response is %d.%n", + usage.getPromptTokens(), usage.getTotalTokens()); + } +} diff --git a/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/pulsar/functions/transforms/GenAITest.java b/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/pulsar/functions/transforms/GenAITest.java index 983b3c512..ef5fcdca5 100644 --- a/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/pulsar/functions/transforms/GenAITest.java +++ b/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/pulsar/functions/transforms/GenAITest.java @@ -29,6 +29,8 @@ import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.models.ChatCompletions; import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatRequestMessage; +import com.azure.ai.openai.models.ChatRequestUserMessage; import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource; import com.datastax.oss.streaming.ai.services.OpenAIServiceProvider; import com.fasterxml.jackson.databind.ObjectMapper; @@ -221,7 +223,18 @@ void testChatCompletions() throws Exception { ArgumentCaptor.forClass(ChatCompletionsOptions.class); verify(client).getChatCompletionsStream(eq("test-model"), captor.capture()); - assertEquals(captor.getValue().getMessages().get(0).getContent(), "value1 key2"); + captor.getAllValues().forEach(System.out::println); + + List messages = captor.getValue().getMessages(); + ChatRequestMessage firstMessage = messages.get(0); + + if (firstMessage instanceof ChatRequestUserMessage) { + ChatRequestUserMessage userMessage = (ChatRequestUserMessage) firstMessage; + String messageContent = userMessage.getContent().toString(); + assertEquals("value1 key2", messageContent); + } else { + throw new AssertionError("Expected first message to be of type ChatRequestUserMessage"); + } } @Test diff --git a/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/ChatCompletionsStepTest.java b/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/ChatCompletionsStepTest.java index b4ea5ea1e..bbcb3dcc8 100644 --- a/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/ChatCompletionsStepTest.java +++ b/langstream-agents/langstream-ai-agents/src/test/java/com/datastax/oss/streaming/ai/ChatCompletionsStepTest.java @@ -28,6 +28,8 @@ import com.azure.ai.openai.OpenAIAsyncClient; import com.azure.ai.openai.models.ChatCompletions; import com.azure.ai.openai.models.ChatCompletionsOptions; +import com.azure.ai.openai.models.ChatRequestMessage; +import com.azure.ai.openai.models.ChatRequestUserMessage; import com.datastax.oss.streaming.ai.completions.ChatMessage; import com.datastax.oss.streaming.ai.model.config.ChatCompletionsConfig; import com.fasterxml.jackson.databind.JsonNode; @@ -119,10 +121,18 @@ void testPrimitive() throws Exception { "{{ value }} {{ key}} {{ eventTime }} {{ topicName }} {{ destinationTopic }} {{ properties.test-key }}"))); Utils.process(record, new ChatCompletionsStep(completionService, config)); verify(openAIClient).getChatCompletionsStream(eq("test-model"), captor.capture()); - - assertEquals( - captor.getValue().getMessages().get(0).getContent(), - "test-message test-key 42 test-input-topic test-output-topic test-value"); + List messages = captor.getValue().getMessages(); + ChatRequestMessage firstMessage = messages.get(0); + + if (firstMessage instanceof ChatRequestUserMessage) { + ChatRequestUserMessage userMessage = (ChatRequestUserMessage) firstMessage; + String messageContent = userMessage.getContent().toString(); + assertEquals( + "test-message test-key 42 test-input-topic test-output-topic test-value", + messageContent); + } else { + throw new AssertionError("Expected first message to be of type ChatRequestUserMessage"); + } } @Test @@ -153,9 +163,18 @@ void testPrimitiveNoStream() throws Exception { Utils.process(record, new ChatCompletionsStep(completionService, config)); verify(openAIClient).getChatCompletions(eq("test-model"), captor.capture()); - assertEquals( - captor.getValue().getMessages().get(0).getContent(), - "test-message test-key 42 test-input-topic test-output-topic test-value"); + List messages = captor.getValue().getMessages(); + ChatRequestMessage firstMessage = messages.get(0); + + if (firstMessage instanceof ChatRequestUserMessage) { + ChatRequestUserMessage userMessage = (ChatRequestUserMessage) firstMessage; + String messageContent = userMessage.getContent().toString(); + assertEquals( + "test-message test-key 42 test-input-topic test-output-topic test-value", + messageContent); + } else { + throw new AssertionError("Expected first message to be of type ChatRequestUserMessage"); + } } public static Object[][] structuredSchemaTypes() { @@ -205,9 +224,16 @@ void testStructured(SchemaType schemaType) throws Exception { ArgumentCaptor.forClass(ChatCompletionsOptions.class); verify(openAIClient).getChatCompletionsStream(eq("test-model"), captor.capture()); - assertEquals( - captor.getValue().getMessages().get(0).getContent(), - "Jane Doe 42 19359 1672700645006 83045006 test-key"); + List messages = captor.getValue().getMessages(); + ChatRequestMessage firstMessage = messages.get(0); + + if (firstMessage instanceof ChatRequestUserMessage) { + ChatRequestUserMessage userMessage = (ChatRequestUserMessage) firstMessage; + String messageContent = userMessage.getContent().toString(); + assertEquals("Jane Doe 42 19359 1672700645006 83045006 test-key", messageContent); + } else { + throw new AssertionError("Expected first message to be of type ChatRequestUserMessage"); + } } public static Object[][] jsonStringSchemas() { @@ -243,9 +269,16 @@ void testJsonString(Schema schema) throws Exception { ArgumentCaptor.forClass(ChatCompletionsOptions.class); verify(openAIClient).getChatCompletionsStream(eq("test-model"), captor.capture()); - assertEquals( - captor.getValue().getMessages().get(0).getContent(), - "Jane Doe 42 19359 1672700645006 83045006 test-key"); + List messages = captor.getValue().getMessages(); + ChatRequestMessage firstMessage = messages.get(0); + + if (firstMessage instanceof ChatRequestUserMessage) { + ChatRequestUserMessage userMessage = (ChatRequestUserMessage) firstMessage; + String messageContent = userMessage.getContent().toString(); + assertEquals("Jane Doe 42 19359 1672700645006 83045006 test-key", messageContent); + } else { + throw new AssertionError("Expected first message to be of type ChatRequestUserMessage"); + } } @ParameterizedTest @@ -263,8 +296,16 @@ void testKVStructured(SchemaType schemaType) throws Exception { Utils.createTestStructKeyValueRecord(schemaType), new ChatCompletionsStep(completionService, config)); verify(openAIClient).getChatCompletionsStream(eq("test-model"), captor.capture()); + List messages = captor.getValue().getMessages(); + ChatRequestMessage firstMessage = messages.get(0); - assertEquals(captor.getValue().getMessages().get(0).getContent(), "value1 key2"); + if (firstMessage instanceof ChatRequestUserMessage) { + ChatRequestUserMessage userMessage = (ChatRequestUserMessage) firstMessage; + String messageContent = userMessage.getContent().toString(); + assertEquals("value1 key2", messageContent); + } else { + throw new AssertionError("Expected first message to be of type ChatRequestUserMessage"); + } } @ParameterizedTest @@ -297,9 +338,16 @@ void testKVJsonString(Schema schema) throws Exception { ArgumentCaptor.forClass(ChatCompletionsOptions.class); verify(openAIClient).getChatCompletionsStream(eq("test-model"), captor.capture()); - assertEquals( - captor.getValue().getMessages().get(0).getContent(), - "Jane Doe 42 19359 1672700645006 83045006 test-key"); + List messages = captor.getValue().getMessages(); + ChatRequestMessage firstMessage = messages.get(0); + + if (firstMessage instanceof ChatRequestUserMessage) { + ChatRequestUserMessage userMessage = (ChatRequestUserMessage) firstMessage; + String messageContent = userMessage.getContent().toString(); + assertEquals("Jane Doe 42 19359 1672700645006 83045006 test-key", messageContent); + } else { + throw new AssertionError("Expected first message to be of type ChatRequestUserMessage"); + } } @Test diff --git a/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/ComputeAIEmbeddingsConfiguration.java b/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/ComputeAIEmbeddingsConfiguration.java index b25a37be8..6358d74d8 100644 --- a/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/ComputeAIEmbeddingsConfiguration.java +++ b/langstream-core/src/main/java/ai/langstream/impl/agents/ai/steps/ComputeAIEmbeddingsConfiguration.java @@ -88,6 +88,15 @@ public void generateSteps( @JsonProperty("embeddings-field") private String embeddingsField; + @ConfigProperty( + description = + """ + Vector dimensions to use when calculating the embedding. Applies to Open AI test-embedding-3 models. + """, + required = false) + @JsonProperty("dimensions") + private Integer dimensions = 0; + @ConfigProperty( description = """ diff --git a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/agents/KubernetesGenAIToolKitFunctionAgentProviderTest.java b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/agents/KubernetesGenAIToolKitFunctionAgentProviderTest.java index c0178ec72..254ee3a03 100644 --- a/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/agents/KubernetesGenAIToolKitFunctionAgentProviderTest.java +++ b/langstream-k8s-runtime/langstream-k8s-runtime-core/src/test/java/ai/langstream/runtime/impl/k8s/agents/KubernetesGenAIToolKitFunctionAgentProviderTest.java @@ -501,6 +501,11 @@ public void testDocumentation() { "type" : "integer", "defaultValue" : "4" }, + "dimensions" : { + "description" : "Vector dimensions to use when calculating the embedding. Applies to Open AI test-embedding-3 models.", + "required" : false, + "type" : "integer" + }, "embeddings-field" : { "description" : "Field where to store the embeddings.", "required" : true, diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ChatCompletionsIT.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ChatCompletionsIT.java index d81d904db..7f8cbf677 100644 --- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ChatCompletionsIT.java +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ChatCompletionsIT.java @@ -63,7 +63,7 @@ public void testChatCompletionWithStreaming(boolean legacy) throws Exception { String model = "gpt-35-turbo"; stubFor( - post("/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-08-01-preview") + post("/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-03-01-preview") .withRequestBody( equalTo( "{\"messages\":[{\"role\":\"user\",\"content\":\"What can you tell me about the car ?\"}],\"stream\":true}")) diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ComputeEmbeddingsIT.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ComputeEmbeddingsIT.java index 7e3f8e1d5..2db103796 100644 --- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ComputeEmbeddingsIT.java +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/ComputeEmbeddingsIT.java @@ -30,17 +30,24 @@ import ai.langstream.api.model.TopicDefinition; import ai.langstream.api.runtime.ExecutionPlan; import ai.langstream.api.runtime.Topic; +import com.azure.core.util.Base64Util; import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; import com.github.tomakehurst.wiremock.matching.MultiValuePattern; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; import java.text.SimpleDateFormat; import java.util.ArrayList; +import java.util.Base64; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.function.Consumer; +import java.util.stream.Collectors; import java.util.stream.Stream; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -135,14 +142,14 @@ private static Stream providers() { .formatted(wireMockRuntimeInfo.getHttpBaseUrl()), () -> stubFor( - post("/openai/deployments/text-embedding-ada-002/embeddings?api-version=2023-08-01-preview") + post("/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-03-01-preview") .willReturn( okJson( """ { "data": [ { - "embedding": [1.0, 5.4, 8.7], + "embedding": "AACAP83MrEAzMwtB", "index": 0, "object": "embedding" } @@ -155,7 +162,7 @@ private static Stream providers() { } } """))), - Set.of("[1.0,5.4,8.7]"))); + Set.of("[1.0,5.400000095367432,8.699999809265137]"))); Arguments huggingFaceApi = Arguments.of( new EmbeddingsConfig( @@ -372,28 +379,36 @@ public void testComputeBatchEmbeddings(boolean sameKey) throws Exception { log.info("Removing stub {}", stubMapping); wireMockRuntimeInfo.getWireMock().removeStubMapping(stubMapping); }); - String embeddingFirst = "[1.0,5.4,8.7]"; - String embeddingSecond = "[2.0,5.4,8.7]"; - String embeddingThird = "[3.0,5.4,8.7]"; + List floatListFirst = List.of(1.0f, 5.400000095367432f, 8.699999809265137f); + // Need to convert to base64 since Azure SDK always requests embeddings in base64 format + String embeddingFirst64 = convertFloatListToBase64(floatListFirst); + String embeddingFirst = convertFloatListToString(floatListFirst); + log.info("Embedding first: {}", embeddingFirst); + List floatListSecond = List.of(2.0f, 5.400000095367432f, 8.699999809265137f); + String embeddingSecond64 = convertFloatListToBase64(floatListSecond); + String embeddingSecond = convertFloatListToString(floatListSecond); + List floatListThird = List.of(3.0f, 5.400000095367432f, 8.699999809265137f); + String embeddingThird64 = convertFloatListToBase64(floatListThird); + String embeddingThird = convertFloatListToString(floatListThird); stubFor( - post("/openai/deployments/text-embedding-ada-002/embeddings?api-version=2023-08-01-preview") + post("/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-03-01-preview") .willReturn( okJson( """ { "data": [ { - "embedding": %s, + "embedding": "%s", "index": 0, "object": "embedding" }, { - "embedding": %s, + "embedding": "%s", "index": 0, "object": "embedding" }, { - "embedding": %s, + "embedding": "%s", "index": 0, "object": "embedding" } @@ -407,9 +422,9 @@ public void testComputeBatchEmbeddings(boolean sameKey) throws Exception { } """ .formatted( - embeddingFirst, - embeddingSecond, - embeddingThird)))); + embeddingFirst64, + embeddingSecond64, + embeddingThird64)))); // wait for WireMock to be ready Thread.sleep(1000); @@ -538,6 +553,44 @@ tenant, appId, application, buildInstanceYaml(), expectedAgents)) { } } + public static String convertFloatListToString(List floatList) { + return "[" + + floatList.stream() + .map(f -> BigDecimal.valueOf(f).toPlainString()) + .collect(Collectors.joining(",")) + + "]"; + } + + // This method converts a base64 string to a list of floats + public static List convertBase64ToFloatList(String embedding) { + byte[] bytes = Base64Util.decodeString(embedding); + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + FloatBuffer floatBuffer = byteBuffer.asFloatBuffer(); + List floatList = new ArrayList<>(floatBuffer.remaining()); + while (floatBuffer.hasRemaining()) { + floatList.add(floatBuffer.get()); + } + return floatList; + } + + public static String convertFloatListToBase64(List floatList) { + ByteBuffer byteBuffer = ByteBuffer.allocate(floatList.size() * 4); + byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + for (Float f : floatList) { + byteBuffer.putFloat(f); + } + byte[] bytes = byteBuffer.array(); + return Base64.getEncoder().encodeToString(bytes); + } + + public static void main(String[] args) { + // Example usage + List floatList = List.of(1.0f, 5.4f, 8.7f); + String base64String = convertFloatListToBase64(floatList); + System.out.println(base64String); + } + @Test public void testLegacySyntax() throws Exception { wireMockRuntimeInfo @@ -549,17 +602,23 @@ public void testLegacySyntax() throws Exception { log.info("Removing stub {}", stubMapping); wireMockRuntimeInfo.getWireMock().removeStubMapping(stubMapping); }); - String embeddingFirst = "[1.0,5.4,8.7]"; + List floatListFirst = List.of(1.0f, 5.400000095367432f, 8.699999809265137f); + // Need to convert to base64 since Azure SDK always requests embeddings in base64 format + String base64String = convertFloatListToBase64(floatListFirst); + log.info("Embedding first: {}", base64String); + stubFor( - post("/openai/deployments/text-embedding-ada-002/embeddings?api-version=2023-08-01-preview") - .withRequestBody(equalTo("{\"input\":[\"something to embed foo\"]}")) + post("/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-03-01-preview") + .withRequestBody( + equalTo( + "{\"input\":[\"something to embed foo\"],\"encoding_format\":\"base64\"}")) .willReturn( okJson( """ { "data": [ { - "embedding": %s, + "embedding": "%s", "index": 0, "object": "embedding" } @@ -572,7 +631,7 @@ public void testLegacySyntax() throws Exception { } } """ - .formatted(embeddingFirst)))); + .formatted(base64String)))); // wait for WireMock to be ready Thread.sleep(1000); @@ -655,7 +714,9 @@ tenant, appId, application, buildInstanceYaml(), expectedAgents)) { executeAgentRunners(applicationRuntime); waitForMessages( - consumer, List.of("{\"name\":\"foo\",\"embeddings\":[1.0,5.4,8.7]}")); + consumer, + List.of( + "{\"name\":\"foo\",\"embeddings\":[1.0,5.400000095367432,8.699999809265137]}")); } } } diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/FlareControllerAgentRunnerIT.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/FlareControllerAgentRunnerIT.java index 91933ba44..1ea8be7ec 100644 --- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/FlareControllerAgentRunnerIT.java +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/FlareControllerAgentRunnerIT.java @@ -21,8 +21,13 @@ import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Base64; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.kafka.clients.consumer.KafkaConsumer; import org.apache.kafka.clients.producer.KafkaProducer; @@ -57,16 +62,19 @@ public static void stopDatabase() { @Test public void testSimpleFlare(WireMockRuntimeInfo wireMockRuntimeInfo) throws Exception { - String embeddingFirst = "[1.0,5.4,8.7,7,9]"; + // String embeddingFirst = "[1.0,5.4,8.7,7,9]"; + List floatListFirst = List.of(1.0f, 5.400000095367432f, 8.699999809265137f, 7f, 9f); + // Need to convert to base64 since Azure SDK always requests embeddings in base64 format + String embeddingFirst64 = convertFloatListToBase64(floatListFirst); stubFor( - post("/openai/deployments/text-embeddings-ada/embeddings?api-version=2023-08-01-preview") + post("/openai/deployments/text-embeddings-ada/embeddings?api-version=2024-03-01-preview") .willReturn( okJson( """ { "data": [ { - "embedding": %s, + "embedding": "%s", "index": 0, "object": "embedding" } @@ -79,9 +87,9 @@ public void testSimpleFlare(WireMockRuntimeInfo wireMockRuntimeInfo) throws Exce } } """ - .formatted(embeddingFirst)))); + .formatted(embeddingFirst64)))); stubFor( - post("/openai/deployments/gp-3.5-turbo-instruct/completions?api-version=2023-08-01-preview") + post("/openai/deployments/gp-3.5-turbo-instruct/completions?api-version=2024-03-01-preview") .willReturn( okJson( """ @@ -376,4 +384,22 @@ ORDER BY cosine_similarity(embeddings_vector, CAST(? as FLOAT ARRAY)) DESC LIMIT } } } + + public static String convertFloatListToBase64(List floatList) { + ByteBuffer byteBuffer = ByteBuffer.allocate(floatList.size() * 4); + byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + for (Float f : floatList) { + byteBuffer.putFloat(f); + } + byte[] bytes = byteBuffer.array(); + return Base64.getEncoder().encodeToString(bytes); + } + + public static String convertFloatListToString(List floatList) { + return "[" + + floatList.stream() + .map(f -> BigDecimal.valueOf(f).toPlainString()) + .collect(Collectors.joining(",")) + + "]"; + } } diff --git a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/TextCompletionsIT.java b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/TextCompletionsIT.java index b6497eb32..933f3dd5a 100644 --- a/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/TextCompletionsIT.java +++ b/langstream-runtime/langstream-runtime-impl/src/test/java/ai/langstream/kafka/TextCompletionsIT.java @@ -57,7 +57,7 @@ public void testTextCompletionsWithLogProbs(WireMockRuntimeInfo vmRuntimeInfo) String model = "gpt-3.5-turbo-instruct"; stubFor( - post("/openai/deployments/gpt-3.5-turbo-instruct/completions?api-version=2023-08-01-preview") + post("/openai/deployments/gpt-3.5-turbo-instruct/completions?api-version=2024-03-01-preview") .withRequestBody( equalTo( "{\"prompt\":[\"What can you tell me about the car ?\"],\"logprobs\":5,\"stream\":true}")) diff --git a/pom.xml b/pom.xml index 88db9e0f7..91eadb88a 100644 --- a/pom.xml +++ b/pom.xml @@ -35,8 +35,8 @@ 7.4.0 1.10.2 - 1.0.0-beta.4 - 1.2.16 + 1.0.0-beta.8 + 1.2.24 4.1.97.Final 2.0.61.Final