Skip to content
This repository has been archived by the owner on Aug 25, 2024. It is now read-only.

Commit

Permalink
Openai embedding dimensions and newer Azure SDK (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
cdbartholomew authored Jun 13, 2024
1 parent eca29ef commit 5927f8c
Show file tree
Hide file tree
Showing 17 changed files with 374 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ pipeline:
type: "compute-ai-embeddings"
output: "output-topic"
configuration:
model: "${secrets.open-ai.embeddings-model}" # This needs to match the name of the model deployment, not the base model
model: "text-embedding-3-large" # This needs to match the name of the model deployment, not the base model
embeddings-field: "value.embeddings"
text: "{{ value.text }}"
dimensions: 256
batch-size: 10
# this is in milliseconds. It is important to take this value into consideration when using this agent in the chat response pipeline
# in fact this value impacts the latency of the response
Expand Down
4 changes: 3 additions & 1 deletion examples/applications/openai-completions/pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,7 @@ pipeline:
# eventually we want to send bigger messages to reduce the overhead of each message on the topic
min-chunks-per-message: 10
messages:
- role: system
content: "You are a helpful assistant."
- role: user
content: "You are a helpful assistant. Below you can find a question from the user. Please try to help them the best way you can.\n\n{{ value.question}}"
content: "{{ value.question }}"
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
import ai.langstream.api.runner.code.MetricsReporter;
import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRole;
import com.azure.ai.openai.models.ChatRequestAssistantMessage;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.CompletionsLogProbabilityModel;
import com.azure.ai.openai.models.CompletionsOptions;
Expand Down Expand Up @@ -124,16 +127,29 @@ public CompletableFuture<ChatCompletions> getChatCompletions(
StreamingChunksConsumer streamingChunksConsumer,
Map<String, Object> options) {
int minChunksPerMessage = getInteger("min-chunks-per-message", 20, options);

List<ChatRequestMessage> chatMessages =
messages.stream()
.map(
message -> {
switch (message.getRole()) {
case "system":
return new ChatRequestSystemMessage(
message.getContent());
case "user":
return new ChatRequestUserMessage(message.getContent());
case "assistant":
return new ChatRequestAssistantMessage(
message.getContent());
default:
throw new IllegalArgumentException(
"Unknown chat role: " + message.getRole());
}
})
.collect(Collectors.toList());

ChatCompletionsOptions chatCompletionsOptions =
new ChatCompletionsOptions(
messages.stream()
.map(
message ->
new com.azure.ai.openai.models.ChatMessage(
ChatRole.fromString(
message.getRole()),
message.getContent()))
.collect(Collectors.toList()))
new ChatCompletionsOptions(chatMessages)
.setMaxTokens(getInteger("max-tokens", null, options))
.setTemperature(getDouble("temperature", null, options))
.setTopP(getDouble("top-p", null, options))
Expand All @@ -143,6 +159,7 @@ public CompletableFuture<ChatCompletions> getChatCompletions(
.setStop((List<String>) options.get("stop"))
.setPresencePenalty(getDouble("presence-penalty", null, options))
.setFrequencyPenalty(getDouble("frequency-penalty", null, options));

ChatCompletions result = new ChatCompletions();
chatNumCalls.count(1);
// this is the default behavior, as it is async
Expand Down Expand Up @@ -211,7 +228,7 @@ public CompletableFuture<ChatCompletions> getChatCompletions(
}

private static ChatMessage convertMessage(com.azure.ai.openai.models.ChatChoice c) {
com.azure.ai.openai.models.ChatMessage message = c.getMessage();
com.azure.ai.openai.models.ChatResponseMessage message = c.getMessage();
if (message == null) {
message = c.getDelta();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class OpenAIEmbeddingsService implements EmbeddingsService {

private final OpenAIAsyncClient openAIClient;
private final String model;
private final Integer dimensions;

private final MetricsReporter.Counter totalTokens;
private final MetricsReporter.Counter promptTokens;
Expand All @@ -38,9 +39,13 @@ public class OpenAIEmbeddingsService implements EmbeddingsService {
private final MetricsReporter.Counter numErrors;

public OpenAIEmbeddingsService(
OpenAIAsyncClient openAIClient, String model, MetricsReporter metricsReporter) {
OpenAIAsyncClient openAIClient,
String model,
MetricsReporter metricsReporter,
Integer dimensions) {
this.openAIClient = openAIClient;
this.model = model;
this.dimensions = dimensions;
this.totalTokens =
metricsReporter.counter(
"openai_embeddings_total_tokens",
Expand All @@ -65,8 +70,13 @@ public OpenAIEmbeddingsService(
public CompletableFuture<List<List<Double>>> computeEmbeddings(List<String> texts) {
try {
EmbeddingsOptions embeddingsOptions = new EmbeddingsOptions(texts);
if (dimensions > 0) {
log.debug("Setting embedding dimensions to {}", dimensions);
embeddingsOptions.setDimensions(dimensions);
}
numCalls.count(1);
numTexts.count(texts.size());

CompletableFuture<List<List<Double>>> result =
openAIClient
.getEmbeddings(model, embeddingsOptions)
Expand All @@ -78,12 +88,12 @@ public CompletableFuture<List<List<Double>>> computeEmbeddings(List<String> text
promptTokens.count(usage.getPromptTokens());
return embeddings.getData().stream()
.map(EmbeddingItem::getEmbedding)
.map(this::convertToDoubleList)
.collect(Collectors.toList());
});

result.exceptionally(
err -> {
// API call error
numErrors.count(1);
return null;
});
Expand All @@ -94,4 +104,10 @@ public CompletableFuture<List<List<Double>>> computeEmbeddings(List<String> text
return CompletableFuture.failedFuture(err);
}
}

/**
 * Widens an embedding vector of {@code Float} values into {@code Double}s.
 *
 * @param floats the raw embedding returned by the OpenAI client
 * @return a new list with every element converted to {@code double}
 */
private List<Double> convertToDoubleList(List<Float> floats) {
    // Trace the vector length (handy when validating the "dimensions" setting).
    log.debug("Float list length: {}", floats.size());
    return floats.stream()
            .map(value -> value.doubleValue())
            .collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,7 @@ public class ComputeAIEmbeddingsConfig extends StepConfig {

@JsonProperty(value = "model-url")
String modelUrl;

@JsonProperty(value = "dimensions")
private int dimensions = 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ public CompletionsService getCompletionsService(Map<String, Object> additionalCo
@Override
public EmbeddingsService getEmbeddingsService(Map<String, Object> additionalConfiguration) {
    String model = (String) additionalConfiguration.get("model");
    // "dimensions" may be absent from the configuration, or deserialized as a
    // different Number type (e.g. Long). The original blind (Integer) cast returned
    // null for a missing key, which later NPEs on the unboxing comparison
    // `dimensions > 0` in OpenAIEmbeddingsService.computeEmbeddings. Default to 0,
    // the sentinel meaning "do not set an explicit dimension".
    Object rawDimensions = additionalConfiguration.get("dimensions");
    Integer dimensions =
            rawDimensions instanceof Number ? ((Number) rawDimensions).intValue() : 0;
    return new OpenAIEmbeddingsService(client, model, metricsReporter, dimensions);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class OpenAIProviderTest {
void testStreamingChatCompletion(WireMockRuntimeInfo vmRuntimeInfo) throws Exception {
resetWiremockStubs(vmRuntimeInfo);
stubFor(
post("/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-08-01-preview")
post("/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-03-01-preview")
.willReturn(
okJson(
"""
Expand Down Expand Up @@ -127,7 +127,7 @@ public void consumeChunk(
void testStreamingTextCompletion(WireMockRuntimeInfo vmRuntimeInfo) throws Exception {
resetWiremockStubs(vmRuntimeInfo);
stubFor(
post("/openai/deployments/gpt-35-turbo-instruct/completions?api-version=2023-08-01-preview")
post("/openai/deployments/gpt-35-turbo-instruct/completions?api-version=2024-03-01-preview")
.withRequestBody(
equalTo(
"{\"prompt\":[\"Translate from English to Italian: \\\"I love cars\\\" with quotes\"],\"stream\":true}"))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai.langstream.ai.agents.services.impl;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
import com.azure.ai.openai.models.EmbeddingsOptions;
import com.azure.ai.openai.models.EmbeddingsUsage;
import com.azure.core.credential.AzureKeyCredential;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

/**
 * Manual smoke tests for the Azure OpenAI Java SDK wiring.
 *
 * <p>Both tests are {@code @Disabled} because they call the live service and require a
 * real API key; fill in credentials and enable them locally to run.
 */
public class OpenAITest {

    // Async client built in setup(); used by the streaming chat-completions test.
    private OpenAIAsyncClient openAIClient;

    @BeforeEach
    void setup() {
        // NOTE(review): no endpoint is configured here, only an AzureKeyCredential with a
        // placeholder key — confirm whether this is meant to target Azure OpenAI (which
        // needs .endpoint(...)) or the non-Azure OpenAI service.
        openAIClient =
                new OpenAIClientBuilder()
                        .credential(new AzureKeyCredential("YOUR_OPENAI_KEY"))
                        .buildAsyncClient();
    }

    @Disabled
    @Test
    void testRealChatCompletions() throws Exception {
        // Single user message sent through the new ChatRequestMessage-based API.
        List<ChatRequestMessage> chatMessages = new ArrayList<>();
        chatMessages.add(new ChatRequestUserMessage("Name the US presidents of the 20th century"));

        ChatCompletionsOptions options = new ChatCompletionsOptions(chatMessages);

        // Completed when the stream appears finished, or completed exceptionally on error.
        CompletableFuture<Void> completableFuture = new CompletableFuture<>();

        openAIClient
                .getChatCompletionsStream("gpt-3.5-turbo", options)
                .doOnNext(
                        chatCompletions -> {
                            String response =
                                    chatCompletions.getChoices().get(0).getDelta().getContent();
                            System.out.println("Response: " + response);
                            // assumes the final streamed chunk carries a null content delta —
                            // TODO confirm against the SDK's streaming contract
                            if (response == null) {
                                completableFuture.complete(null);
                            }
                        })
                .doOnError(completableFuture::completeExceptionally)
                .subscribe();

        // Block until the stream signals completion so the test does not exit early.
        completableFuture.join();
    }

    @Disabled
    @Test
    void testRealEmbeddings() throws Exception {
        // Fill in a real key before enabling this test; empty key will fail at the service.
        String azureOpenaiKey = "";
        String deploymentOrModelId = "text-embedding-ada-002";

        // Synchronous client, built independently of the async one from setup().
        OpenAIClient client =
                new OpenAIClientBuilder()
                        .credential(new AzureKeyCredential(azureOpenaiKey))
                        .buildClient();

        EmbeddingsOptions embeddingsOptions =
                new EmbeddingsOptions(Arrays.asList("Your text string goes here"));

        Embeddings embeddings = client.getEmbeddings(deploymentOrModelId, embeddingsOptions);

        // Print each returned embedding both as a base64 string and as raw floats.
        for (EmbeddingItem item : embeddings.getData()) {
            System.out.printf("Index: %d.%n", item.getPromptIndex());
            System.out.println(
                    "Embedding as base64 encoded string: " + item.getEmbeddingAsString());
            System.out.println("Embedding as list of floats: ");
            for (Float embedding : item.getEmbedding()) {
                System.out.printf("%f;", embedding);
            }
        }

        // Token accounting reported by the service for this request.
        EmbeddingsUsage usage = embeddings.getUsage();
        System.out.printf(
                "Usage: number of prompt token is %d and number of total tokens in request and response is %d.%n",
                usage.getPromptTokens(), usage.getTotalTokens());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.datastax.oss.streaming.ai.datasource.QueryStepDataSource;
import com.datastax.oss.streaming.ai.services.OpenAIServiceProvider;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand Down Expand Up @@ -221,7 +223,18 @@ void testChatCompletions() throws Exception {
ArgumentCaptor.forClass(ChatCompletionsOptions.class);
verify(client).getChatCompletionsStream(eq("test-model"), captor.capture());

assertEquals(captor.getValue().getMessages().get(0).getContent(), "value1 key2");
captor.getAllValues().forEach(System.out::println);

List<ChatRequestMessage> messages = captor.getValue().getMessages();
ChatRequestMessage firstMessage = messages.get(0);

if (firstMessage instanceof ChatRequestUserMessage) {
ChatRequestUserMessage userMessage = (ChatRequestUserMessage) firstMessage;
String messageContent = userMessage.getContent().toString();
assertEquals("value1 key2", messageContent);
} else {
throw new AssertionError("Expected first message to be of type ChatRequestUserMessage");
}
}

@Test
Expand Down
Loading

0 comments on commit 5927f8c

Please sign in to comment.