From 7296801e2558869fc2c5194a9f43c1b77531981d Mon Sep 17 00:00:00 2001 From: Ido Berkovich Date: Wed, 25 Dec 2024 09:28:30 +0200 Subject: [PATCH] OPIK-610 move retry logic to ChatCompletionService --- .../opik/domain/ChatCompletionService.java | 25 +++++++++++++++++-- .../llmproviders/LlmProviderFactory.java | 17 +------------ .../opik/domain/llmproviders/OpenAi.java | 7 ++---- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java index a8b5fd33fb..06e98e8bc2 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java @@ -2,30 +2,41 @@ import com.comet.opik.domain.llmproviders.LlmProviderFactory; import com.comet.opik.domain.llmproviders.LlmProviderStreamHandler; +import com.comet.opik.infrastructure.LlmProviderClientConfig; import dev.ai4j.openai4j.chat.ChatCompletionRequest; import dev.ai4j.openai4j.chat.ChatCompletionResponse; +import dev.langchain4j.internal.RetryUtils; import jakarta.inject.Inject; import jakarta.inject.Singleton; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.glassfish.jersey.server.ChunkedOutput; +import ru.vyarus.dropwizard.guice.module.yaml.bind.Config; + +import java.util.Optional; @Singleton @Slf4j public class ChatCompletionService { + private final LlmProviderClientConfig llmProviderClientConfig; private final LlmProviderFactory llmProviderFactory; private final LlmProviderStreamHandler streamHandler; + private final RetryUtils.RetryPolicy retryPolicy; @Inject - public ChatCompletionService(LlmProviderFactory llmProviderFactory, LlmProviderStreamHandler streamHandler) { + public ChatCompletionService( + @NonNull @Config LlmProviderClientConfig llmProviderClientConfig, + LlmProviderFactory llmProviderFactory, LlmProviderStreamHandler streamHandler) { + this.llmProviderClientConfig = llmProviderClientConfig; this.llmProviderFactory = llmProviderFactory; this.streamHandler = streamHandler; + this.retryPolicy = newRetryPolicy(); } public ChatCompletionResponse create(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) { log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model()); var llmProviderClient = llmProviderFactory.getService(workspaceId, request.model()); - var chatCompletionResponse = llmProviderClient.generate(request, workspaceId); + var chatCompletionResponse = retryPolicy.withRetry(() -> llmProviderClient.generate(request, workspaceId)); log.info("Created chat completions, workspaceId '{}', model '{}'", workspaceId, request.model()); return chatCompletionResponse; } @@ -38,4 +49,14 @@ public ChunkedOutput createAndStreamResponse( log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model()); return chunkedOutput; } + + private RetryUtils.RetryPolicy newRetryPolicy() { + var retryPolicyBuilder = RetryUtils.retryPolicyBuilder(); + Optional.ofNullable(llmProviderClientConfig.getMaxAttempts()).ifPresent(retryPolicyBuilder::maxAttempts); + Optional.ofNullable(llmProviderClientConfig.getJitterScale()).ifPresent(retryPolicyBuilder::jitterScale); + Optional.ofNullable(llmProviderClientConfig.getBackoffExp()).ifPresent(retryPolicyBuilder::backoffExp); + return retryPolicyBuilder + .delayMillis(llmProviderClientConfig.getDelayMillis()) + .build(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java index d72f7e86df..86c1931c67 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java @@ -4,20 +4,16 @@ import com.comet.opik.domain.LlmProviderApiKeyService; import com.comet.opik.infrastructure.EncryptionUtils; import com.comet.opik.infrastructure.LlmProviderClientConfig; -import dev.langchain4j.internal.RetryUtils; import jakarta.inject.Inject; import jakarta.inject.Singleton; import jakarta.ws.rs.BadRequestException; import lombok.NonNull; import ru.vyarus.dropwizard.guice.module.yaml.bind.Config; -import java.util.Optional; - @Singleton public class LlmProviderFactory { private final LlmProviderClientConfig llmProviderClientConfig; private final LlmProviderApiKeyService llmProviderApiKeyService; - private final RetryUtils.RetryPolicy retryPolicy; @Inject public LlmProviderFactory( @@ -25,14 +21,13 @@ public LlmProviderFactory( @NonNull LlmProviderApiKeyService llmProviderApiKeyService) { this.llmProviderApiKeyService = llmProviderApiKeyService; this.llmProviderClientConfig = llmProviderClientConfig; - this.retryPolicy = newRetryPolicy(); } public LlmProviderService getService(@NonNull String workspaceId, @NonNull String model) { var llmProvider = getLlmProvider(model); if (llmProvider == LlmProvider.OPEN_AI) { var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider)); - return new OpenAi(llmProviderClientConfig, retryPolicy, apiKey); + return new OpenAi(llmProviderClientConfig, apiKey); } throw new IllegalArgumentException("not supported provider " + llmProvider); @@ -60,14 +55,4 @@ private String getEncryptedApiKey(String workspaceId, LlmProvider llmProvider) { llmProvider.getValue()))) .apiKey(); } - - private RetryUtils.RetryPolicy newRetryPolicy() { - var retryPolicyBuilder = RetryUtils.retryPolicyBuilder(); - Optional.ofNullable(llmProviderClientConfig.getMaxAttempts()).ifPresent(retryPolicyBuilder::maxAttempts); - Optional.ofNullable(llmProviderClientConfig.getJitterScale()).ifPresent(retryPolicyBuilder::jitterScale); - Optional.ofNullable(llmProviderClientConfig.getBackoffExp()).ifPresent(retryPolicyBuilder::backoffExp); - return retryPolicyBuilder - .delayMillis(llmProviderClientConfig.getDelayMillis()) - .build(); - } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java index 6b486247b5..7bd16a463b 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/OpenAi.java @@ -5,7 +5,6 @@ import dev.ai4j.openai4j.OpenAiHttpException; import dev.ai4j.openai4j.chat.ChatCompletionRequest; import dev.ai4j.openai4j.chat.ChatCompletionResponse; -import dev.langchain4j.internal.RetryUtils; import io.dropwizard.jersey.errors.ErrorMessage; import jakarta.inject.Inject; import jakarta.ws.rs.ClientErrorException; @@ -23,13 +22,11 @@ public class OpenAi implements LlmProviderService { private static final String UNEXPECTED_ERROR_CALLING_LLM_PROVIDER = "Unexpected error calling LLM provider"; private final LlmProviderClientConfig llmProviderClientConfig; - private final RetryUtils.RetryPolicy retryPolicy; private final OpenAiClient openAiClient; @Inject - public OpenAi(LlmProviderClientConfig llmProviderClientConfig, RetryUtils.RetryPolicy retryPolicy, String apiKey) { + public OpenAi(LlmProviderClientConfig llmProviderClientConfig, String apiKey) { this.llmProviderClientConfig = llmProviderClientConfig; - this.retryPolicy = retryPolicy; this.openAiClient = newOpenAiClient(apiKey); } @@ -38,7 +35,7 @@ public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @ log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model()); ChatCompletionResponse chatCompletionResponse; try { - chatCompletionResponse = retryPolicy.withRetry(() -> openAiClient.chatCompletion(request).execute()); + chatCompletionResponse = openAiClient.chatCompletion(request).execute(); } catch (RuntimeException runtimeException) { log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, runtimeException); if (runtimeException.getCause() instanceof OpenAiHttpException openAiHttpException) {