OPIK-546: Implement create chat completions endpoint (#890)
andrescrz authored Dec 13, 2024
1 parent 0d63dbf commit 1415da1
Showing 17 changed files with 541 additions and 144 deletions.
12 changes: 12 additions & 0 deletions apps/opik-backend/config.yml
@@ -91,3 +91,15 @@ cors:

encryption:
key: ${OPIK_ENCRYPTION_KEY:-'GiTHubiLoVeYouAA'}

+llmProviderClient:
+  maxAttempts: ${LLM_PROVIDER_CLIENT_MAX_ATTEMPTS:-3}
+  delayMillis: ${LLM_PROVIDER_CLIENT_DELAY_MILLIS:-500}
+  jitterScale: ${LLM_PROVIDER_CLIENT_JITTER_SCALE:-0.2}
+  backoffExp: ${LLM_PROVIDER_CLIENT_BACKOFF_EXP:-1.5}
+  callTimeout: ${LLM_PROVIDER_CLIENT_CALL_TIMEOUT:-60s}
+  connectTimeout: ${LLM_PROVIDER_CLIENT_CONNECT_TIMEOUT:-60s}
+  readTimeout: ${LLM_PROVIDER_CLIENT_READ_TIMEOUT:-60s}
+  writeTimeout: ${LLM_PROVIDER_CLIENT_WRITE_TIMEOUT:-60s}
+  openApiClient:
+    url: ${LLM_PROVIDER_CLIENT_OPEN_API_CLIENT_URL:-}
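These keys bind to the new LlmProviderClientConfig that ChatCompletionService reads below. The class itself is one of the collapsed files in this commit, so here is a plausible sketch inferred purely from how the service consumes it; the field names mirror the YAML keys, while the annotations and defaults are assumptions:

import io.dropwizard.util.Duration;
import lombok.Data;

// Hypothetical reconstruction of LlmProviderClientConfig (the real class is in
// this commit but not rendered above). Dropwizard's Duration parses "60s".
@Data
public class LlmProviderClientConfig {

    public record OpenApiClientConfig(String url) {
    }

    private Integer maxAttempts;
    private int delayMillis;
    private Double jitterScale;
    private Double backoffExp;
    private Duration callTimeout;
    private Duration connectTimeout;
    private Duration readTimeout;
    private Duration writeTimeout;
    private OpenApiClientConfig openApiClient;
}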
apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java
@@ -17,10 +17,12 @@
import com.comet.opik.infrastructure.ratelimit.RateLimitModule;
import com.comet.opik.infrastructure.redis.RedisModule;
import com.comet.opik.utils.JsonBigDecimalDeserializer;
+import com.comet.opik.utils.OpenAiMessageJsonDeserializer;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.module.SimpleModule;
+import dev.ai4j.openai4j.chat.Message;
import io.dropwizard.configuration.EnvironmentVariableSubstitutor;
import io.dropwizard.configuration.SubstitutingSourceProvider;
import io.dropwizard.core.Application;
@@ -89,7 +91,9 @@ public void run(OpikConfiguration configuration, Environment environment) {
environment.getObjectMapper().setPropertyNamingStrategy(PropertyNamingStrategies.SnakeCaseStrategy.INSTANCE);
environment.getObjectMapper().configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false);
environment.getObjectMapper()
-                .registerModule(new SimpleModule().addDeserializer(BigDecimal.class, new JsonBigDecimalDeserializer()));
+                .registerModule(new SimpleModule()
+                        .addDeserializer(BigDecimal.class, JsonBigDecimalDeserializer.INSTANCE)
+                        .addDeserializer(Message.class, OpenAiMessageJsonDeserializer.INSTANCE));

jersey.property(ServerProperties.RESPONSE_SET_STATUS_OVER_SEND_ERROR, true);
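The Message deserializer registered above is needed because openai4j's Message is an interface with per-role implementations, so Jackson cannot pick a concrete subtype on its own. OpenAiMessageJsonDeserializer is collapsed in this view; a minimal sketch of the idea, assuming dispatch on the JSON "role" field (not the actual implementation):

import java.io.IOException;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.ai4j.openai4j.chat.AssistantMessage;
import dev.ai4j.openai4j.chat.Message;
import dev.ai4j.openai4j.chat.SystemMessage;
import dev.ai4j.openai4j.chat.ToolMessage;
import dev.ai4j.openai4j.chat.UserMessage;

public class OpenAiMessageJsonDeserializer extends JsonDeserializer<Message> {

    public static final OpenAiMessageJsonDeserializer INSTANCE = new OpenAiMessageJsonDeserializer();

    @Override
    public Message deserialize(JsonParser parser, DeserializationContext context) throws IOException {
        var mapper = (ObjectMapper) parser.getCodec();
        JsonNode node = mapper.readTree(parser);
        // Dispatch on the "role" discriminator to the matching concrete type.
        return switch (node.path("role").asText()) {
            case "system" -> mapper.treeToValue(node, SystemMessage.class);
            case "user" -> mapper.treeToValue(node, UserMessage.class);
            case "assistant" -> mapper.treeToValue(node, AssistantMessage.class);
            case "tool" -> mapper.treeToValue(node, ToolMessage.class);
            default -> throw new IllegalArgumentException("Unknown message role");
        };
    }
}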

apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java
@@ -1,13 +1,26 @@
package com.comet.opik.api;

-import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonValue;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;

+import java.util.Arrays;

+@Getter
+@RequiredArgsConstructor
public enum LlmProvider {

@JsonProperty("openai")
OPEN_AI;
OPEN_AI("openai");

+    @JsonValue
+    private final String value;
+
+    @JsonCreator
+    public static LlmProvider fromString(String value) {
+        return Arrays.stream(values())
+                .filter(llmProvider -> llmProvider.value.equals(value))
+                .findFirst()
+                .orElseThrow(() -> new IllegalArgumentException("Unknown llm provider '%s'".formatted(value)));
+    }
}
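A quick round-trip of the new wire mapping (illustrative snippet, not part of the commit; exception handling omitted): @JsonValue serializes the constant as its lowercase value, and the @JsonCreator factory parses it back, rejecting unknown providers:

var mapper = new com.fasterxml.jackson.databind.ObjectMapper();
mapper.writeValueAsString(LlmProvider.OPEN_AI);    // "openai"
mapper.readValue("\"openai\"", LlmProvider.class); // OPEN_AI
LlmProvider.fromString("unknown");                 // throws IllegalArgumentException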
apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKey.java
@@ -8,8 +8,8 @@
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
+import jakarta.validation.constraints.NotNull;
import lombok.Builder;
-import lombok.NonNull;

import java.time.Instant;
import java.util.List;
@@ -21,7 +21,7 @@
public record ProviderApiKey(
@JsonView( {
View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID id,
-        @JsonView({View.Public.class, View.Write.class}) @NonNull LlmProvider provider,
+        @JsonView({View.Public.class, View.Write.class}) @NotNull LlmProvider provider,
@JsonView({
View.Write.class}) @NotBlank @JsonDeserialize(using = ProviderApiKeyDeserializer.class) String apiKey,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResource.java
@@ -1,12 +1,10 @@
package com.comet.opik.api.resources.v1.priv;

import com.codahale.metrics.annotation.Timed;
-import com.comet.opik.domain.TextStreamer;
+import com.comet.opik.domain.ChatCompletionService;
import com.comet.opik.infrastructure.auth.RequestContext;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
-import dev.ai4j.openai4j.shared.CompletionTokensDetails;
-import dev.ai4j.openai4j.shared.Usage;
import io.dropwizard.jersey.errors.ErrorMessage;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.media.ArraySchema;
@@ -28,10 +26,6 @@
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
-import reactor.core.publisher.Flux;
-
-import java.security.SecureRandom;
-import java.util.UUID;

@Path("/v1/private/chat/completions")
@Produces(MediaType.APPLICATION_JSON)
@@ -43,52 +37,34 @@
public class ChatCompletionsResource {

private final @NonNull Provider<RequestContext> requestContextProvider;
-    private final @NonNull TextStreamer textStreamer;
-    private final @NonNull SecureRandom secureRandom;
+    private final @NonNull ChatCompletionService chatCompletionService;

@POST
@Produces({MediaType.SERVER_SENT_EVENTS, MediaType.APPLICATION_JSON})
@Operation(operationId = "getChatCompletions", summary = "Get chat completions", description = "Get chat completions", responses = {
@ApiResponse(responseCode = "501", description = "Chat completions response", content = {
@Operation(operationId = "createChatCompletions", summary = "Create chat completions", description = "Create chat completions", responses = {
@ApiResponse(responseCode = "200", description = "Chat completions response", content = {
@Content(mediaType = "text/event-stream", array = @ArraySchema(schema = @Schema(type = "object", anyOf = {
ChatCompletionResponse.class,
ErrorMessage.class}))),
@Content(mediaType = "application/json", schema = @Schema(implementation = ChatCompletionResponse.class))}),
})
-    public Response get(
+    public Response create(
@RequestBody(content = @Content(schema = @Schema(implementation = ChatCompletionRequest.class))) @NotNull @Valid ChatCompletionRequest request) {
var workspaceId = requestContextProvider.get().getWorkspaceId();
String type;
Object entity;
if (Boolean.TRUE.equals(request.stream())) {
log.info("Streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'",
workspaceId, request.model());
type = MediaType.SERVER_SENT_EVENTS;
-            var flux = Flux.range(0, 10).map(i -> newResponse(request.model()));
-            entity = textStreamer.getOutputStream(flux);
+            entity = chatCompletionService.createAndStreamResponse(request, workspaceId);
} else {
log.info("Getting chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
type = MediaType.APPLICATION_JSON;
-            entity = newResponse(request.model());
+            entity = chatCompletionService.create(request, workspaceId);
}
-        var response = Response.status(Response.Status.NOT_IMPLEMENTED).type(type).entity(entity).build();
-        log.info("Returned chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
+        var response = Response.ok().type(type).entity(entity).build();
+        log.info("Created chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return response;
}

-    private ChatCompletionResponse newResponse(String model) {
-        return ChatCompletionResponse.builder()
-                .id(UUID.randomUUID().toString())
-                .created((int) (System.currentTimeMillis() / 1000))
-                .model(model)
-                .usage(Usage.builder()
-                        .totalTokens(secureRandom.nextInt())
-                        .promptTokens(secureRandom.nextInt())
-                        .completionTokens(secureRandom.nextInt())
-                        .completionTokensDetails(CompletionTokensDetails.builder()
-                                .reasoningTokens(secureRandom.nextInt())
-                                .build())
-                        .build())
-                .systemFingerprint(UUID.randomUUID().toString())
-                .build();
-    }
}
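For orientation, a hypothetical client-side call against the new endpoint (host, port, model, and payload are illustrative, not from this commit; auth headers and checked-exception handling omitted). With "stream": true the same endpoint responds as text/event-stream, emitting CRLF-separated JSON chunks:

var client = java.net.http.HttpClient.newHttpClient();
var request = java.net.http.HttpRequest.newBuilder()
        .uri(java.net.URI.create("http://localhost:8080/v1/private/chat/completions"))
        .header("Content-Type", "application/json")
        .POST(java.net.http.HttpRequest.BodyPublishers.ofString(
                "{\"model\": \"gpt-4o-mini\", \"stream\": false, "
                        + "\"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}]}"))
        .build();
var response = client.send(request, java.net.http.HttpResponse.BodyHandlers.ofString());
System.out.println(response.body()); // a single ChatCompletionResponse JSON document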
apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java (new file)
@@ -0,0 +1,185 @@
package com.comet.opik.domain;

import com.comet.opik.api.LlmProvider;
import com.comet.opik.infrastructure.EncryptionUtils;
import com.comet.opik.infrastructure.LlmProviderClientConfig;
import com.comet.opik.utils.JsonUtils;
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.langchain4j.internal.RetryUtils;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.ClientErrorException;
import jakarta.ws.rs.InternalServerErrorException;
import jakarta.ws.rs.ServerErrorException;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ChunkedOutput;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Optional;

@Singleton
@Slf4j
public class ChatCompletionService {

private static final String UNEXPECTED_ERROR_CALLING_LLM_PROVIDER = "Unexpected error calling LLM provider";

private final LlmProviderClientConfig llmProviderClientConfig;
private final LlmProviderApiKeyService llmProviderApiKeyService;
private final RetryUtils.RetryPolicy retryPolicy;

@Inject
public ChatCompletionService(
@NonNull @Config LlmProviderClientConfig llmProviderClientConfig,
@NonNull LlmProviderApiKeyService llmProviderApiKeyService) {
this.llmProviderApiKeyService = llmProviderApiKeyService;
this.llmProviderClientConfig = llmProviderClientConfig;
this.retryPolicy = newRetryPolicy();
}

private RetryUtils.RetryPolicy newRetryPolicy() {
var retryPolicyBuilder = RetryUtils.retryPolicyBuilder();
Optional.ofNullable(llmProviderClientConfig.getMaxAttempts()).ifPresent(retryPolicyBuilder::maxAttempts);
Optional.ofNullable(llmProviderClientConfig.getJitterScale()).ifPresent(retryPolicyBuilder::jitterScale);
Optional.ofNullable(llmProviderClientConfig.getBackoffExp()).ifPresent(retryPolicyBuilder::backoffExp);
return retryPolicyBuilder
.delayMillis(llmProviderClientConfig.getDelayMillis())
.build();
}
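    // Illustrative schedule under the config.yml defaults, assuming conventional
    // exponential backoff (delayMillis * backoffExp^(attempt - 1), plus up to
    // jitterScale of random jitter); this is an assumption about langchain4j's
    // RetryUtils, not documented behavior:
    //   attempt 1 fails -> wait ~500ms; attempt 2 fails -> wait ~750ms;
    //   attempt 3 fails -> stop (maxAttempts = 3) and propagate the error.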

public ChatCompletionResponse create(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
var openAiClient = getAndConfigureOpenAiClient(request, workspaceId);
ChatCompletionResponse chatCompletionResponse;
try {
chatCompletionResponse = retryPolicy.withRetry(() -> openAiClient.chatCompletion(request).execute());
} catch (RuntimeException runtimeException) {
log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, runtimeException);
if (runtimeException.getCause() instanceof OpenAiHttpException openAiHttpException) {
if (openAiHttpException.code() >= 400 && openAiHttpException.code() <= 499) {
throw new ClientErrorException(openAiHttpException.getMessage(), openAiHttpException.code());
}
throw new ServerErrorException(openAiHttpException.getMessage(), openAiHttpException.code());
}
throw new InternalServerErrorException(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
}
log.info("Created chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chatCompletionResponse;
}
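    // The catch block above preserves the provider's HTTP status for client
    // errors (e.g. an invalid key or unknown model comes back with its original
    // 4xx code), forwards other HTTP failures with their upstream code, and
    // collapses anything unrecognized into a generic 500.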

public ChunkedOutput<String> createAndStreamResponse(
@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
var openAiClient = getAndConfigureOpenAiClient(request, workspaceId);
var chunkedOutput = new ChunkedOutput<String>(String.class, "\r\n");
openAiClient.chatCompletion(request)
.onPartialResponse(chatCompletionResponse -> send(chatCompletionResponse, chunkedOutput))
.onComplete(() -> close(chunkedOutput))
.onError(throwable -> handle(throwable, chunkedOutput))
.execute();
log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chunkedOutput;
}
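    // Each chunk written to the ChunkedOutput above is one serialized JSON
    // object (a partial ChatCompletionResponse, or an ErrorMessage on failure),
    // separated by the "\r\n" delimiter passed to the constructor.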

private OpenAiClient getAndConfigureOpenAiClient(ChatCompletionRequest request, String workspaceId) {
var llmProvider = getLlmProvider(request.model());
var encryptedApiKey = getEncryptedApiKey(workspaceId, llmProvider);
return newOpenAiClient(encryptedApiKey);
}

    /**
     * The agreed requirement is to resolve the LLM provider and its API key based on the model.
     * Currently, only OpenAI is supported, so the model param is ignored.
     * No further validation is needed on the model, as it's just forwarded in the OpenAI request and will be
     * rejected if invalid.
     */
private LlmProvider getLlmProvider(String model) {
return LlmProvider.OPEN_AI;
}
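    // Illustrative only, not in this commit: once more providers exist,
    // resolution could dispatch on the model prefix, e.g. "gpt-*" to OPEN_AI
    // and "claude-*" to a hypothetical ANTHROPIC constant.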

    /**
     * Finding API keys isn't paginated at the moment, since only OpenAI is supported.
     * Even in the future, the number of supported LLM providers per workspace is expected to stay very low.
     */
private String getEncryptedApiKey(String workspaceId, LlmProvider llmProvider) {
return llmProviderApiKeyService.find(workspaceId).content().stream()
.filter(providerApiKey -> llmProvider.equals(providerApiKey.provider()))
.findFirst()
.orElseThrow(() -> new BadRequestException("API key not configured for LLM provider '%s'".formatted(
llmProvider.getValue())))
.apiKey();
}

    /**
     * Initially, only OpenAI is supported, so there's no need for more sophisticated client resolution to start with.
     * At the moment, neither the openai4j client nor the langchain4j wrappers support dynamic API keys, which can
     * imply a significant performance penalty in later phases, as a new client is built per request. The following
     * options should be evaluated:
     * - Cache clients, though this can be unsafe.
     * - Find and evaluate other clients.
     * - Implement our own client.
     * TODO as part of: <a href="https://comet-ml.atlassian.net/browse/OPIK-522">OPIK-522</a>
     */
private OpenAiClient newOpenAiClient(String encryptedApiKey) {
var openAiClientBuilder = OpenAiClient.builder();
Optional.ofNullable(llmProviderClientConfig.getOpenApiClient())
.map(LlmProviderClientConfig.OpenApiClientConfig::url)
.ifPresent(baseUrl -> {
if (StringUtils.isNotBlank(baseUrl)) {
openAiClientBuilder.baseUrl(baseUrl);
}
});
Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
.ifPresent(callTimeout -> openAiClientBuilder.callTimeout(callTimeout.toJavaDuration()));
Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
.ifPresent(connectTimeout -> openAiClientBuilder.connectTimeout(connectTimeout.toJavaDuration()));
Optional.ofNullable(llmProviderClientConfig.getReadTimeout())
.ifPresent(readTimeout -> openAiClientBuilder.readTimeout(readTimeout.toJavaDuration()));
Optional.ofNullable(llmProviderClientConfig.getWriteTimeout())
.ifPresent(writeTimeout -> openAiClientBuilder.writeTimeout(writeTimeout.toJavaDuration()));
return openAiClientBuilder
.openAiApiKey(EncryptionUtils.decrypt(encryptedApiKey))
.build();
}
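    // Sketch of the "cache clients" option from the Javadoc above (not part of
    // this commit): memoizing per encrypted key would avoid rebuilding a client
    // on every request, at the cost of invalidation when a workspace rotates its key:
    //   private final Map<String, OpenAiClient> clients = new ConcurrentHashMap<>();
    //   clients.computeIfAbsent(encryptedApiKey, this::newOpenAiClient);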

private void send(Object item, ChunkedOutput<String> chunkedOutput) {
if (chunkedOutput.isClosed()) {
log.warn("Output stream is already closed");
return;
}
try {
chunkedOutput.write(JsonUtils.writeValueAsString(item));
} catch (IOException ioException) {
throw new UncheckedIOException(ioException);
}
}

private void handle(Throwable throwable, ChunkedOutput<String> chunkedOutput) {
log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);
var errorMessage = new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
if (throwable instanceof OpenAiHttpException openAiHttpException) {
errorMessage = new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
}
try {
send(errorMessage, chunkedOutput);
} catch (UncheckedIOException uncheckedIOException) {
log.error("Failed to stream error message to client", uncheckedIOException);
}
close(chunkedOutput);
}
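    // Note: by the time an upstream error surfaces, the 200 status and part of
    // the stream may already be on the wire, so the error travels in-band as a
    // final ErrorMessage chunk rather than as an HTTP error status.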

private void close(ChunkedOutput<String> chunkedOutput) {
try {
chunkedOutput.close();
} catch (IOException ioException) {
log.error("Failed to close output stream", ioException);
}
}
}