OPIK-610 move chunked output logic to ChatCompletionService
idoberko2 committed Dec 25, 2024
1 parent 7296801 commit d469df6
Showing 4 changed files with 84 additions and 77 deletions.
ChatCompletionService.java
@@ -1,35 +1,39 @@
package com.comet.opik.domain;

import com.comet.opik.domain.llmproviders.LlmProviderFactory;
import com.comet.opik.domain.llmproviders.LlmProviderStreamHandler;
import com.comet.opik.infrastructure.LlmProviderClientConfig;
import com.comet.opik.utils.JsonUtils;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.langchain4j.internal.RetryUtils;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Function;

@Singleton
@Slf4j
public class ChatCompletionService {
private static final String UNEXPECTED_ERROR_CALLING_LLM_PROVIDER = "Unexpected error calling LLM provider";

private final LlmProviderClientConfig llmProviderClientConfig;
private final LlmProviderFactory llmProviderFactory;
private final LlmProviderStreamHandler streamHandler;
private final RetryUtils.RetryPolicy retryPolicy;

@Inject
public ChatCompletionService(
@NonNull @Config LlmProviderClientConfig llmProviderClientConfig,
LlmProviderFactory llmProviderFactory, LlmProviderStreamHandler streamHandler) {
@NonNull @Config LlmProviderClientConfig llmProviderClientConfig, LlmProviderFactory llmProviderFactory) {
this.llmProviderClientConfig = llmProviderClientConfig;
this.llmProviderFactory = llmProviderFactory;
this.streamHandler = streamHandler;
this.retryPolicy = newRetryPolicy();
}

@@ -45,7 +49,14 @@ public ChunkedOutput<String> createAndStreamResponse(
@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
var llmProviderClient = llmProviderFactory.getService(workspaceId, request.model());
var chunkedOutput = llmProviderClient.generateStream(request, workspaceId, streamHandler);

var chunkedOutput = new ChunkedOutput<String>(String.class, "\r\n");
llmProviderClient.generateStream(
request,
workspaceId,
getMessageHandler(chunkedOutput),
getCloseHandler(chunkedOutput),
getErrorHandler(chunkedOutput, llmProviderClient::mapError));
log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chunkedOutput;
}
@@ -59,4 +70,43 @@ private RetryUtils.RetryPolicy newRetryPolicy() {
.delayMillis(llmProviderClientConfig.getDelayMillis())
.build();
}

private <T> Consumer<T> getMessageHandler(ChunkedOutput<String> chunkedOutput) {
return item -> {
if (chunkedOutput.isClosed()) {
log.warn("Output stream is already closed");
return;
}
try {
chunkedOutput.write(JsonUtils.writeValueAsString(item));
} catch (IOException ioException) {
throw new UncheckedIOException(ioException);
}
};
}

private Runnable getCloseHandler(ChunkedOutput<String> chunkedOutput) {
return () -> {
try {
chunkedOutput.close();
} catch (IOException ioException) {
log.error("Failed to close output stream", ioException);
}
};
}

private Consumer<Throwable> getErrorHandler(
ChunkedOutput<String> chunkedOutput, Function<Throwable, ErrorMessage> errorMapper) {
return throwable -> {
log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);

var errorMessage = errorMapper.apply(throwable);
try {
getMessageHandler(chunkedOutput).accept(errorMessage);
} catch (UncheckedIOException uncheckedIOException) {
log.error("Failed to stream error message to client", uncheckedIOException);
}
getCloseHandler(chunkedOutput).run();
};
}
}
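
For orientation, here is a minimal sketch of how a JAX-RS resource could expose the new createAndStreamResponse method. The resource class, package, path, and workspace-id handling are hypothetical and not part of this commit; only the ChatCompletionService signature is taken from the diff above.

// Hypothetical resource; class name, package, path, and the hard-coded workspace id are illustrative only.
package com.comet.opik.api.resources;

import com.comet.opik.domain.ChatCompletionService;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import jakarta.inject.Inject;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;
import org.glassfish.jersey.server.ChunkedOutput;

@Path("/v1/private/chat/completions")
public class ChatCompletionsResourceSketch {

    private final ChatCompletionService chatCompletionService;

    @Inject
    public ChatCompletionsResourceSketch(ChatCompletionService chatCompletionService) {
        this.chatCompletionService = chatCompletionService;
    }

    @POST
    @Produces(MediaType.APPLICATION_JSON)
    public ChunkedOutput<String> createStream(ChatCompletionRequest request) {
        // The returned ChunkedOutput is filled asynchronously by the message/close/error
        // handlers that ChatCompletionService registers on the provider client.
        return chatCompletionService.createAndStreamResponse(request, "workspace-id-from-auth-context");
    }
}

Because ChatCompletionService now owns the handlers, the resource only has to hand the ChunkedOutput back to Jersey, which streams each written chunk to the client as it arrives.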
LlmProviderService.java
@@ -2,16 +2,22 @@

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import io.dropwizard.jersey.errors.ErrorMessage;
import lombok.NonNull;
import org.glassfish.jersey.server.ChunkedOutput;

import java.util.function.Consumer;

public interface LlmProviderService {
ChatCompletionResponse generate(
@NonNull ChatCompletionRequest request,
@NonNull String workspaceId);

ChunkedOutput<String> generateStream(
void generateStream(
@NonNull ChatCompletionRequest request,
@NonNull String workspaceId,
@NonNull LlmProviderStreamHandler streamHandler);
@NonNull Consumer<ChatCompletionResponse> handleMessage,
@NonNull Runnable handleClose,
@NonNull Consumer<Throwable> handleError);

ErrorMessage mapError(Throwable throwable);
}
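
To illustrate the revised contract, below is a skeletal provider implementation. The class itself (EchoProviderSketch) is hypothetical and not part of this commit; only the interface shape comes from the diff above.

// Hypothetical skeleton of a provider implementing the callback-based interface; not part of this commit.
package com.comet.opik.domain.llmproviders;

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import io.dropwizard.jersey.errors.ErrorMessage;
import lombok.NonNull;

import java.util.function.Consumer;

public class EchoProviderSketch implements LlmProviderService {

    @Override
    public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
        throw new UnsupportedOperationException("the non-streaming path is out of scope for this sketch");
    }

    @Override
    public void generateStream(
            @NonNull ChatCompletionRequest request,
            @NonNull String workspaceId,
            @NonNull Consumer<ChatCompletionResponse> handleMessage,
            @NonNull Runnable handleClose,
            @NonNull Consumer<Throwable> handleError) {
        try {
            // A real provider registers handleMessage on its streaming client so every partial
            // ChatCompletionResponse reaches the caller, then signals completion via handleClose.
            handleClose.run();
        } catch (RuntimeException exception) {
            handleError.accept(exception);
        }
    }

    @Override
    public ErrorMessage mapError(Throwable throwable) {
        // ChatCompletionService serializes this ErrorMessage into the chunked output on failure.
        return new ErrorMessage("Unexpected error calling LLM provider: " + throwable.getMessage());
    }
}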

LlmProviderStreamHandler.java: This file was deleted.

OpenAi.java
@@ -13,9 +13,9 @@
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ChunkedOutput;

import java.util.Optional;
import java.util.function.Consumer;

@Slf4j
public class OpenAi implements LlmProviderService {
@@ -51,18 +51,28 @@ public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @
}

@Override
public ChunkedOutput<String> generateStream(@NonNull ChatCompletionRequest request, @NonNull String workspaceId,
@NonNull LlmProviderStreamHandler streamHandler) {
public void generateStream(
@NonNull ChatCompletionRequest request,
@NonNull String workspaceId,
@NonNull Consumer<ChatCompletionResponse> handleMessage,
@NonNull Runnable handleClose,
@NonNull Consumer<Throwable> handleError) {
log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
var chunkedOutput = new ChunkedOutput<String>(String.class, "\r\n");
openAiClient.chatCompletion(request)
.onPartialResponse(
chatCompletionResponse -> streamHandler.handleMessage(chatCompletionResponse, chunkedOutput))
.onComplete(() -> streamHandler.handleClose(chunkedOutput))
.onError(streamHandler.getErrorHandler(this::errorMapper, chunkedOutput))
.onPartialResponse(handleMessage)
.onComplete(handleClose)
.onError(handleError)
.execute();
log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chunkedOutput;
}

@Override
public ErrorMessage mapError(Throwable throwable) {
if (throwable instanceof OpenAiHttpException openAiHttpException) {
return new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
}

return new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
}

/**
@@ -95,12 +105,4 @@ private OpenAiClient newOpenAiClient(String apiKey) {
.openAiApiKey(apiKey)
.build();
}

private ErrorMessage errorMapper(Throwable throwable) {
if (throwable instanceof OpenAiHttpException openAiHttpException) {
return new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
}

return new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
}
}
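
Since the chunked output now lives in ChatCompletionService, every chunk written to the client is one JSON document followed by the "\r\n" delimiter configured there. A hypothetical client-side sketch (the endpoint URL and request body are placeholders, not part of this commit) could read the stream line by line:

// Hypothetical consumer of the "\r\n"-delimited stream; the URL and payload are placeholders.
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.stream.Stream;

public class ChunkedStreamClientSketch {
    public static void main(String[] args) throws Exception {
        var client = HttpClient.newHttpClient();
        var request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/v1/private/chat/completions"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString("{\"model\":\"gpt-4o-mini\",\"stream\":true}"))
                .build();
        HttpResponse<Stream<String>> response = client.send(request, HttpResponse.BodyHandlers.ofLines());
        // Each line is either a serialized ChatCompletionResponse or, after a failure, an ErrorMessage.
        response.body().forEach(System.out::println);
    }
}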
