Skip to content

Commit

Permalink
OPIK-610 move retry logic to ChatCompletionService
Browse files — browse the repository at this point in the history
  • Loading branch information
idoberko2 committed Dec 25, 2024
1 parent 3bfb047 commit 7296801
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,41 @@

import com.comet.opik.domain.llmproviders.LlmProviderFactory;
import com.comet.opik.domain.llmproviders.LlmProviderStreamHandler;
import com.comet.opik.infrastructure.LlmProviderClientConfig;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.langchain4j.internal.RetryUtils;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

import java.util.Optional;

@Singleton
@Slf4j
public class ChatCompletionService {
private final LlmProviderClientConfig llmProviderClientConfig;
private final LlmProviderFactory llmProviderFactory;
private final LlmProviderStreamHandler streamHandler;
private final RetryUtils.RetryPolicy retryPolicy;

@Inject
// NOTE(review): this span is a unified diff with the +/- markers stripped — the
// one-argument constructor on the next line is the PRE-commit (removed) signature
// and the three-argument constructor after it is its replacement; they are not
// both live code in the committed file.
public ChatCompletionService(LlmProviderFactory llmProviderFactory, LlmProviderStreamHandler streamHandler) {
// New signature: additionally injects the LLM provider client config (bound via
// dropwizard-guicey @Config) so this service can build the retry policy that
// this commit moves here from LlmProviderFactory.
public ChatCompletionService(
@NonNull @Config LlmProviderClientConfig llmProviderClientConfig,
LlmProviderFactory llmProviderFactory, LlmProviderStreamHandler streamHandler) {
this.llmProviderClientConfig = llmProviderClientConfig;
this.llmProviderFactory = llmProviderFactory;
this.streamHandler = streamHandler;
// Built once at construction from the injected config (see newRetryPolicy()).
this.retryPolicy = newRetryPolicy();
}

// Synchronously creates a chat completion for the given workspace: resolves the
// provider client for the requested model, then delegates generation to it.
public ChatCompletionResponse create(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
var llmProviderClient = llmProviderFactory.getService(workspaceId, request.model());
// Diff artifact: the next line is the removed pre-commit call; the line after it
// is its replacement, wrapping the same call in the service-level retry policy
// that this commit moves up from the provider layer.
var chatCompletionResponse = llmProviderClient.generate(request, workspaceId);
var chatCompletionResponse = retryPolicy.withRetry(() -> llmProviderClient.generate(request, workspaceId));
log.info("Created chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chatCompletionResponse;
}
Expand All @@ -38,4 +49,14 @@ public ChunkedOutput<String> createAndStreamResponse(
log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chunkedOutput;
}

/**
 * Builds the langchain4j retry policy from the injected provider client config.
 * Optional knobs (max attempts, jitter scale, backoff exponent) are applied only
 * when present in the config, leaving the library defaults in place otherwise;
 * the base delay is always taken from the config.
 */
private RetryUtils.RetryPolicy newRetryPolicy() {
    var config = llmProviderClientConfig;
    var builder = RetryUtils.retryPolicyBuilder();
    if (config.getMaxAttempts() != null) {
        builder.maxAttempts(config.getMaxAttempts());
    }
    if (config.getJitterScale() != null) {
        builder.jitterScale(config.getJitterScale());
    }
    if (config.getBackoffExp() != null) {
        builder.backoffExp(config.getBackoffExp());
    }
    return builder.delayMillis(config.getDelayMillis()).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,30 @@
import com.comet.opik.domain.LlmProviderApiKeyService;
import com.comet.opik.infrastructure.EncryptionUtils;
import com.comet.opik.infrastructure.LlmProviderClientConfig;
import dev.langchain4j.internal.RetryUtils;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.ws.rs.BadRequestException;
import lombok.NonNull;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

import java.util.Optional;

@Singleton
public class LlmProviderFactory {
private final LlmProviderClientConfig llmProviderClientConfig;
private final LlmProviderApiKeyService llmProviderApiKeyService;
private final RetryUtils.RetryPolicy retryPolicy;

@Inject
public LlmProviderFactory(
@NonNull @Config LlmProviderClientConfig llmProviderClientConfig,
@NonNull LlmProviderApiKeyService llmProviderApiKeyService) {
this.llmProviderApiKeyService = llmProviderApiKeyService;
this.llmProviderClientConfig = llmProviderClientConfig;
// Diff artifact: the next line is REMOVED by this commit — retry-policy
// ownership (field and construction) moves to ChatCompletionService.
this.retryPolicy = newRetryPolicy();
}

public LlmProviderService getService(@NonNull String workspaceId, @NonNull String model) {
var llmProvider = getLlmProvider(model);
if (llmProvider == LlmProvider.OPEN_AI) {
var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider));
return new OpenAi(llmProviderClientConfig, retryPolicy, apiKey);
return new OpenAi(llmProviderClientConfig, apiKey);
}

throw new IllegalArgumentException("not supported provider " + llmProvider);
Expand Down Expand Up @@ -60,14 +55,4 @@ private String getEncryptedApiKey(String workspaceId, LlmProvider llmProvider) {
llmProvider.getValue())))
.apiKey();
}

// NOTE(review): this entire helper is DELETED by this commit — retry-policy
// construction moves to ChatCompletionService, which now carries identical logic.
// Shown here only as diff context; optional config knobs are applied via
// Optional.ofNullable so unset values fall back to the builder defaults.
private RetryUtils.RetryPolicy newRetryPolicy() {
var retryPolicyBuilder = RetryUtils.retryPolicyBuilder();
Optional.ofNullable(llmProviderClientConfig.getMaxAttempts()).ifPresent(retryPolicyBuilder::maxAttempts);
Optional.ofNullable(llmProviderClientConfig.getJitterScale()).ifPresent(retryPolicyBuilder::jitterScale);
Optional.ofNullable(llmProviderClientConfig.getBackoffExp()).ifPresent(retryPolicyBuilder::backoffExp);
return retryPolicyBuilder
.delayMillis(llmProviderClientConfig.getDelayMillis())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.langchain4j.internal.RetryUtils;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
import jakarta.ws.rs.ClientErrorException;
Expand All @@ -23,13 +22,11 @@ public class OpenAi implements LlmProviderService {
private static final String UNEXPECTED_ERROR_CALLING_LLM_PROVIDER = "Unexpected error calling LLM provider";

private final LlmProviderClientConfig llmProviderClientConfig;
private final RetryUtils.RetryPolicy retryPolicy;
private final OpenAiClient openAiClient;

@Inject
// Diff artifact: the three-argument constructor below is the removed pre-commit
// signature; the two-argument one after it is its replacement — a RetryPolicy is
// no longer injected here, since this commit moves retrying up to
// ChatCompletionService.
public OpenAi(LlmProviderClientConfig llmProviderClientConfig, RetryUtils.RetryPolicy retryPolicy, String apiKey) {
public OpenAi(LlmProviderClientConfig llmProviderClientConfig, String apiKey) {
this.llmProviderClientConfig = llmProviderClientConfig;
// Removed by this commit along with the retryPolicy constructor parameter/field.
this.retryPolicy = retryPolicy;
this.openAiClient = newOpenAiClient(apiKey);
}

Expand All @@ -38,7 +35,7 @@ public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @
log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
ChatCompletionResponse chatCompletionResponse;
try {
chatCompletionResponse = retryPolicy.withRetry(() -> openAiClient.chatCompletion(request).execute());
chatCompletionResponse = openAiClient.chatCompletion(request).execute();
} catch (RuntimeException runtimeException) {
log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, runtimeException);
if (runtimeException.getCause() instanceof OpenAiHttpException openAiHttpException) {
Expand Down

0 comments on commit 7296801

Please sign in to comment.