From d469df68feb3a1228711e2a7ffee91f7d0383779 Mon Sep 17 00:00:00 2001
From: Ido Berkovich
Date: Wed, 25 Dec 2024 09:58:18 +0200
Subject: [PATCH] OPIK-610 move chunked output logic to ChatCompletionService

---
 .../opik/domain/ChatCompletionService.java    | 62 +++++++++++++++++--
 .../llmproviders/LlmProviderService.java      | 12 +++-
 .../LlmProviderStreamHandler.java             | 51 ---------------
 .../opik/domain/llmproviders/OpenAi.java      | 36 ++++++-----
 4 files changed, 84 insertions(+), 77 deletions(-)
 delete mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderStreamHandler.java

diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java
index 06e98e8bc..395e4f4fd 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java
@@ -1,11 +1,12 @@
 package com.comet.opik.domain;
 
 import com.comet.opik.domain.llmproviders.LlmProviderFactory;
-import com.comet.opik.domain.llmproviders.LlmProviderStreamHandler;
 import com.comet.opik.infrastructure.LlmProviderClientConfig;
+import com.comet.opik.utils.JsonUtils;
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
 import dev.ai4j.openai4j.chat.ChatCompletionResponse;
 import dev.langchain4j.internal.RetryUtils;
+import io.dropwizard.jersey.errors.ErrorMessage;
 import jakarta.inject.Inject;
 import jakarta.inject.Singleton;
 import lombok.NonNull;
@@ -13,23 +14,26 @@
 import org.glassfish.jersey.server.ChunkedOutput;
 import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;
 
+import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Function;
 
 @Singleton
 @Slf4j
 public class ChatCompletionService {
 
+    private static final String UNEXPECTED_ERROR_CALLING_LLM_PROVIDER = "Unexpected error calling LLM provider";
+
     private final LlmProviderClientConfig llmProviderClientConfig;
     private final LlmProviderFactory llmProviderFactory;
-    private final LlmProviderStreamHandler streamHandler;
     private final RetryUtils.RetryPolicy retryPolicy;
 
     @Inject
     public ChatCompletionService(
-            @NonNull @Config LlmProviderClientConfig llmProviderClientConfig,
-            LlmProviderFactory llmProviderFactory, LlmProviderStreamHandler streamHandler) {
+            @NonNull @Config LlmProviderClientConfig llmProviderClientConfig, LlmProviderFactory llmProviderFactory) {
         this.llmProviderClientConfig = llmProviderClientConfig;
         this.llmProviderFactory = llmProviderFactory;
-        this.streamHandler = streamHandler;
         this.retryPolicy = newRetryPolicy();
     }
 
@@ -45,7 +49,14 @@ public ChunkedOutput createAndStreamResponse(
             @NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
         log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
         var llmProviderClient = llmProviderFactory.getService(workspaceId, request.model());
-        var chunkedOutput = llmProviderClient.generateStream(request, workspaceId, streamHandler);
+
+        var chunkedOutput = new ChunkedOutput(String.class, "\r\n");
+        llmProviderClient.generateStream(
+                request,
+                workspaceId,
+                getMessageHandler(chunkedOutput),
+                getCloseHandler(chunkedOutput),
+                getErrorHandler(chunkedOutput, llmProviderClient::mapError));
         log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
         return chunkedOutput;
     }
@@ -59,4 +70,43 @@ private RetryUtils.RetryPolicy newRetryPolicy() {
                 .delayMillis(llmProviderClientConfig.getDelayMillis())
                 .build();
     }
+
+    private Consumer getMessageHandler(ChunkedOutput chunkedOutput) {
+        return item -> {
+            if (chunkedOutput.isClosed()) {
+                log.warn("Output stream is already closed");
+                return;
+            }
+            try {
+                chunkedOutput.write(JsonUtils.writeValueAsString(item));
+            } catch (IOException ioException) {
+                throw new UncheckedIOException(ioException);
+            }
+        };
+    }
+
+    private Runnable getCloseHandler(ChunkedOutput chunkedOutput) {
+        return () -> {
+            try {
+                chunkedOutput.close();
+            } catch (IOException ioException) {
+                log.error("Failed to close output stream", ioException);
+            }
+        };
+    }
+
+    private Consumer getErrorHandler(
+            ChunkedOutput chunkedOutput, Function errorMapper) {
+        return throwable -> {
+            log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);
+
+            var errorMessage = errorMapper.apply(throwable);
+            try {
+                getMessageHandler(chunkedOutput).accept(errorMessage);
+            } catch (UncheckedIOException uncheckedIOException) {
+                log.error("Failed to stream error message to client", uncheckedIOException);
+            }
+            getCloseHandler(chunkedOutput).run();
+        };
+    }
 }
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java
index aeacd1b33..375c65c66 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderService.java
@@ -2,16 +2,22 @@
 
 import dev.ai4j.openai4j.chat.ChatCompletionRequest;
 import dev.ai4j.openai4j.chat.ChatCompletionResponse;
+import io.dropwizard.jersey.errors.ErrorMessage;
 import lombok.NonNull;
-import org.glassfish.jersey.server.ChunkedOutput;
+
+import java.util.function.Consumer;
 
 public interface LlmProviderService {
 
     ChatCompletionResponse generate(
             @NonNull ChatCompletionRequest request, @NonNull String workspaceId);
 
-    ChunkedOutput generateStream(
+    void generateStream(
             @NonNull ChatCompletionRequest request,
             @NonNull String workspaceId,
-            @NonNull LlmProviderStreamHandler streamHandler);
+            @NonNull Consumer handleMessage,
+            @NonNull Runnable handleClose,
+            @NonNull Consumer handleError);
+
+    ErrorMessage mapError(Throwable throwable);
 }
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderStreamHandler.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderStreamHandler.java
deleted file mode 100644
index 97744ba1d..000000000
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderStreamHandler.java
+++ /dev/null
@@ -1,51 +0,0 @@
-package com.comet.opik.domain.llmproviders;
-
-import com.comet.opik.utils.JsonUtils;
-import io.dropwizard.jersey.errors.ErrorMessage;
-import lombok.extern.slf4j.Slf4j;
-import org.glassfish.jersey.server.ChunkedOutput;
-
-import java.io.IOException;
-import java.io.UncheckedIOException;
-import java.util.function.Consumer;
-import java.util.function.Function;
-
-@Slf4j
-public class LlmProviderStreamHandler {
-    private static final String UNEXPECTED_ERROR_CALLING_LLM_PROVIDER = "Unexpected error calling LLM provider";
-
-    public void handleMessage(Object item, ChunkedOutput chunkedOutput) {
-        if (chunkedOutput.isClosed()) {
-            log.warn("Output stream is already closed");
-            return;
-        }
-        try {
-            chunkedOutput.write(JsonUtils.writeValueAsString(item));
-        } catch (IOException ioException) {
-            throw new UncheckedIOException(ioException);
-        }
-    }
-
-    public void handleClose(ChunkedOutput chunkedOutput) {
-        try {
-            chunkedOutput.close();
-        } catch (IOException ioException) {
-            log.error("Failed to close output stream", ioException);
-        }
-    }
-
-    public Consumer getErrorHandler(
-            Function mapper, ChunkedOutput chunkedOutput) {
-        return throwable -> {
-            log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);
-
-            var errorMessage = mapper.apply(throwable);
-            try {
-                handleMessage(errorMessage, chunkedOutput);
-            } catch (UncheckedIOException uncheckedIOException) {
-                log.error("Failed to stream error message to client", uncheckedIOException);
-            }
-            handleClose(chunkedOutput);
-        };
-    }
-}
diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java
index 7bd16a463..accb49184 100644
--- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java
+++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java
@@ -13,9 +13,9 @@
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.lang3.StringUtils;
-import org.glassfish.jersey.server.ChunkedOutput;
 
 import java.util.Optional;
+import java.util.function.Consumer;
 
 @Slf4j
 public class OpenAi implements LlmProviderService {
@@ -51,18 +51,28 @@ public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @
     }
 
     @Override
-    public ChunkedOutput generateStream(@NonNull ChatCompletionRequest request, @NonNull String workspaceId,
-            @NonNull LlmProviderStreamHandler streamHandler) {
+    public void generateStream(
+            @NonNull ChatCompletionRequest request,
+            @NonNull String workspaceId,
+            @NonNull Consumer handleMessage,
+            @NonNull Runnable handleClose,
+            @NonNull Consumer handleError) {
         log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
-        var chunkedOutput = new ChunkedOutput(String.class, "\r\n");
         openAiClient.chatCompletion(request)
-                .onPartialResponse(
-                        chatCompletionResponse -> streamHandler.handleMessage(chatCompletionResponse, chunkedOutput))
-                .onComplete(() -> streamHandler.handleClose(chunkedOutput))
-                .onError(streamHandler.getErrorHandler(this::errorMapper, chunkedOutput))
+                .onPartialResponse(handleMessage)
+                .onComplete(handleClose)
+                .onError(handleError)
                 .execute();
         log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
-        return chunkedOutput;
+    }
+
+    @Override
+    public ErrorMessage mapError(Throwable throwable) {
+        if (throwable instanceof OpenAiHttpException openAiHttpException) {
+            return new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
+        }
+
+        return new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
     }
 
     /**
@@ -95,12 +105,4 @@ private OpenAiClient newOpenAiClient(String apiKey) {
                 .openAiApiKey(apiKey)
                 .build();
     }
-
-    private ErrorMessage errorMapper(Throwable throwable) {
-        if (throwable instanceof OpenAiHttpException openAiHttpException) {
-            return new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
-        }
-
-        return new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
-    }
 }
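
Illustrative sketch (editor's note, not part of the patch): the class below shows how a provider could implement the callback-based LlmProviderService contract introduced above. The class name DummyStreamingProvider and its partialResponses helper are hypothetical, and the Consumer type parameters (ChatCompletionResponse for handleMessage, Throwable for handleError) are assumptions inferred from how OpenAi wires onPartialResponse and onError; the patch text above shows the raw Consumer signatures.

package com.comet.opik.domain.llmproviders;

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import io.dropwizard.jersey.errors.ErrorMessage;
import lombok.NonNull;

import java.util.function.Consumer;
import java.util.stream.Stream;

// Hypothetical provider used only to illustrate the callback-based contract.
public class DummyStreamingProvider implements LlmProviderService {

    @Override
    public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
        throw new UnsupportedOperationException("not relevant to the streaming example");
    }

    @Override
    public void generateStream(
            @NonNull ChatCompletionRequest request,
            @NonNull String workspaceId,
            @NonNull Consumer<ChatCompletionResponse> handleMessage,
            @NonNull Runnable handleClose,
            @NonNull Consumer<Throwable> handleError) {
        try {
            // A real provider registers these callbacks with its SDK client, the way
            // OpenAi wires onPartialResponse/onComplete/onError in the patch above.
            partialResponses(request).forEach(handleMessage); // each chunk is serialized and written by ChatCompletionService
            handleClose.run();                                // lets ChatCompletionService close its ChunkedOutput
        } catch (RuntimeException exception) {
            handleError.accept(exception);                    // the service logs, maps via mapError and streams the ErrorMessage
        }
    }

    @Override
    public ErrorMessage mapError(Throwable throwable) {
        // Provider-specific exceptions would be translated here, mirroring OpenAi.mapError.
        return new ErrorMessage("Unexpected error calling LLM provider");
    }

    // Hypothetical stand-in for a provider SDK; produces no chunks in this sketch.
    private Stream<ChatCompletionResponse> partialResponses(ChatCompletionRequest request) {
        return Stream.empty();
    }
}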
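The practical effect of the patch is that the ChunkedOutput lifecycle (creation with the "\r\n" chunk delimiter, JSON serialization via JsonUtils, closing, and error reporting) now lives entirely in ChatCompletionService, while provider implementations such as OpenAi only push partial responses, completion, and failures through the three callbacks and translate provider-specific exceptions in mapError. New providers therefore no longer depend on Jersey's ChunkedOutput or on the removed LlmProviderStreamHandler.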