OPIK-546: Implement create chat completions endpoint #890

Merged
12 changes: 12 additions & 0 deletions apps/opik-backend/config.yml
@@ -91,3 +91,15 @@ cors:

encryption:
key: ${OPIK_ENCRYPTION_KEY:-'GiTHubiLoVeYouAA'}

llmProviderClient:
maxAttempts: ${LLM_PROVIDER_CLIENT_MAX_ATTEMPTS:-3}
delayMillis: ${LLM_PROVIDER_CLIENT_DELAY_MILLIS:-500}
jitterScale: ${LLM_PROVIDER_CLIENT_JITTER_SCALE:-0.2}
backoffExp: ${LLM_PROVIDER_CLIENT_BACKOFF_EXP:-1.5}
callTimeout: ${LLM_PROVIDER_CLIENT_CALL_TIMEOUT:-60s}
connectTimeout: ${LLM_PROVIDER_CLIENT_CONNECT_TIMEOUT:-60s}
readTimeout: ${LLM_PROVIDER_CLIENT_READ_TIMEOUT:-60s}
writeTimeout: ${LLM_PROVIDER_CLIENT_WRITE_TIMEOUT:-60s}
openApiClient:
url: ${LLM_PROVIDER_CLIENT_OPEN_API_CLIENT_URL:-}
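
A note on how these knobs combine: ChatCompletionService below feeds them into langchain4j's RetryUtils policy builder. The exact schedule is the library's business, but assuming the conventional exponential-backoff-with-jitter formula, the wait before retry n is roughly delayMillis * backoffExp^(n-1), widened by jitterScale. A minimal sketch under that assumption:

public class BackoffSchedule {
    public static void main(String[] args) {
        int maxAttempts = 3;      // LLM_PROVIDER_CLIENT_MAX_ATTEMPTS
        double delayMillis = 500; // LLM_PROVIDER_CLIENT_DELAY_MILLIS
        double jitterScale = 0.2; // LLM_PROVIDER_CLIENT_JITTER_SCALE
        double backoffExp = 1.5;  // LLM_PROVIDER_CLIENT_BACKOFF_EXP
        // Assumed formula; the authoritative one lives in langchain4j's RetryUtils.
        for (int retry = 1; retry < maxAttempts; retry++) {
            double base = delayMillis * Math.pow(backoffExp, retry - 1);
            System.out.printf("wait before retry %d: ~%.0f ms, plus up to %.0f ms jitter%n",
                    retry, base, base * jitterScale);
        }
    }
}

With the defaults above, that works out to roughly 500 ms before the first retry and 750 ms before the second.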
apps/opik-backend/src/main/java/com/comet/opik/OpikApplication.java

@@ -17,10 +17,12 @@
import com.comet.opik.infrastructure.ratelimit.RateLimitModule;
import com.comet.opik.infrastructure.redis.RedisModule;
import com.comet.opik.utils.JsonBigDecimalDeserializer;
+import com.comet.opik.utils.OpenAiMessageJsonDeserializer;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.module.SimpleModule;
+import dev.ai4j.openai4j.chat.Message;
import io.dropwizard.configuration.EnvironmentVariableSubstitutor;
import io.dropwizard.configuration.SubstitutingSourceProvider;
import io.dropwizard.core.Application;
@@ -89,7 +91,9 @@ public void run(OpikConfiguration configuration, Environment environment) {
environment.getObjectMapper().setPropertyNamingStrategy(PropertyNamingStrategies.SnakeCaseStrategy.INSTANCE);
environment.getObjectMapper().configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false);
environment.getObjectMapper()
-                .registerModule(new SimpleModule().addDeserializer(BigDecimal.class, new JsonBigDecimalDeserializer()));
+                .registerModule(new SimpleModule()
+                        .addDeserializer(BigDecimal.class, JsonBigDecimalDeserializer.INSTANCE)
+                        .addDeserializer(Message.class, OpenAiMessageJsonDeserializer.INSTANCE));

jersey.property(ServerProperties.RESPONSE_SET_STATUS_OVER_SEND_ERROR, true);

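The hunk above registers two custom deserializers on Dropwizard's shared ObjectMapper through a Jackson SimpleModule. For readers unfamiliar with the pattern, a self-contained sketch (the LenientBigDecimalDeserializer below is hypothetical, for illustration only):

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;

import java.io.IOException;
import java.math.BigDecimal;

public class ModuleRegistration {

    // Hypothetical deserializer: accepts "12.30" (string) as well as 12.30 (number).
    static class LenientBigDecimalDeserializer extends JsonDeserializer<BigDecimal> {
        @Override
        public BigDecimal deserialize(JsonParser parser, DeserializationContext context) throws IOException {
            return new BigDecimal(parser.getValueAsString());
        }
    }

    public static void main(String[] args) throws IOException {
        var mapper = new ObjectMapper().registerModule(new SimpleModule()
                .addDeserializer(BigDecimal.class, new LenientBigDecimalDeserializer()));
        System.out.println(mapper.readValue("\"12.30\"", BigDecimal.class)); // 12.30
    }
}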
apps/opik-backend/src/main/java/com/comet/opik/api/LlmProvider.java
@@ -1,13 +1,26 @@
package com.comet.opik.api;

-import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonValue;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;

+import java.util.Arrays;

+@Getter
+@RequiredArgsConstructor
public enum LlmProvider {

@JsonProperty("openai")
OPEN_AI;
OPEN_AI("openai");

+    @JsonValue
+    private final String value;
+
+    @JsonCreator
+    public static LlmProvider fromString(String value) {
+        return Arrays.stream(values())
+                .filter(llmProvider -> llmProvider.value.equals(value))
+                .findFirst()
+                .orElseThrow(() -> new IllegalArgumentException("Unknown llm provider '%s'".formatted(value)));
+    }
}
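
With @JsonValue and @JsonCreator replacing the per-constant @JsonProperty, the wire value lives next to the constant and unknown providers fail fast with a descriptive message. A minimal round-trip sketch (plain ObjectMapper, not the application's configured one):

import com.comet.opik.api.LlmProvider;
import com.fasterxml.jackson.databind.ObjectMapper;

public class LlmProviderRoundTrip {
    public static void main(String[] args) throws Exception {
        var mapper = new ObjectMapper();
        System.out.println(mapper.writeValueAsString(LlmProvider.OPEN_AI)); // "openai"
        System.out.println(mapper.readValue("\"openai\"", LlmProvider.class)); // OPEN_AI
        LlmProvider.fromString("anthropic"); // throws IllegalArgumentException: Unknown llm provider 'anthropic'
    }
}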
apps/opik-backend/src/main/java/com/comet/opik/api/ProviderApiKey.java

@@ -8,8 +8,8 @@
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
+import jakarta.validation.constraints.NotNull;
import lombok.Builder;
-import lombok.NonNull;

import java.time.Instant;
import java.util.List;
@@ -21,7 +21,7 @@
public record ProviderApiKey(
@JsonView( {
View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID id,
-        @JsonView({View.Public.class, View.Write.class}) @NonNull LlmProvider provider,
+        @JsonView({View.Public.class, View.Write.class}) @NotNull LlmProvider provider,
@JsonView({
View.Write.class}) @NotBlank @JsonDeserialize(using = ProviderApiKeyDeserializer.class) String apiKey,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
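Context for the @NonNull to @NotNull swap on provider: Lombok's @NonNull generates a null check that throws NullPointerException while the payload is being constructed, which surfaces as a 500, whereas Jakarta's @NotNull is enforced by the bean validation layer, which reports a client error with a proper message. A minimal sketch of the distinction, assuming hibernate-validator on the classpath and using a hypothetical Payload record:

import jakarta.validation.Validation;
import jakarta.validation.constraints.NotNull;

public class NotNullVsNonNull {

    // Hypothetical stand-in for a request payload field like ProviderApiKey.provider.
    record Payload(@NotNull String provider) {}

    public static void main(String[] args) {
        try (var factory = Validation.buildDefaultValidatorFactory()) {
            var payload = new Payload(null); // constructs fine, no NullPointerException
            factory.getValidator().validate(payload)
                    .forEach(violation -> System.out.println(
                            violation.getPropertyPath() + " " + violation.getMessage()));
            // prints: provider must not be null
        }
    }
}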
apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResource.java
@@ -1,12 +1,10 @@
package com.comet.opik.api.resources.v1.priv;

import com.codahale.metrics.annotation.Timed;
-import com.comet.opik.domain.TextStreamer;
+import com.comet.opik.domain.ChatCompletionService;
import com.comet.opik.infrastructure.auth.RequestContext;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
-import dev.ai4j.openai4j.shared.CompletionTokensDetails;
-import dev.ai4j.openai4j.shared.Usage;
import io.dropwizard.jersey.errors.ErrorMessage;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.media.ArraySchema;
@@ -28,10 +26,6 @@
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
-import reactor.core.publisher.Flux;
-
-import java.security.SecureRandom;
-import java.util.UUID;

@Path("/v1/private/chat/completions")
@Produces(MediaType.APPLICATION_JSON)
@@ -43,52 +37,34 @@
public class ChatCompletionsResource {

private final @NonNull Provider<RequestContext> requestContextProvider;
-    private final @NonNull TextStreamer textStreamer;
-    private final @NonNull SecureRandom secureRandom;
+    private final @NonNull ChatCompletionService chatCompletionService;

@POST
@Produces({MediaType.SERVER_SENT_EVENTS, MediaType.APPLICATION_JSON})
@Operation(operationId = "getChatCompletions", summary = "Get chat completions", description = "Get chat completions", responses = {
@ApiResponse(responseCode = "501", description = "Chat completions response", content = {
@Operation(operationId = "createChatCompletions", summary = "Create chat completions", description = "Create chat completions", responses = {
@ApiResponse(responseCode = "200", description = "Chat completions response", content = {
@Content(mediaType = "text/event-stream", array = @ArraySchema(schema = @Schema(type = "object", anyOf = {
ChatCompletionResponse.class,
ErrorMessage.class}))),
@Content(mediaType = "application/json", schema = @Schema(implementation = ChatCompletionResponse.class))}),
})
-    public Response get(
+    public Response create(
@RequestBody(content = @Content(schema = @Schema(implementation = ChatCompletionRequest.class))) @NotNull @Valid ChatCompletionRequest request) {
var workspaceId = requestContextProvider.get().getWorkspaceId();
String type;
Object entity;
if (Boolean.TRUE.equals(request.stream())) {
log.info("Streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'",
workspaceId, request.model());
type = MediaType.SERVER_SENT_EVENTS;
-            var flux = Flux.range(0, 10).map(i -> newResponse(request.model()));
-            entity = textStreamer.getOutputStream(flux);
+            entity = chatCompletionService.createAndStreamResponse(request, workspaceId);
} else {
log.info("Getting chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
type = MediaType.APPLICATION_JSON;
-            entity = newResponse(request.model());
+            entity = chatCompletionService.create(request, workspaceId);
}
-        var response = Response.status(Response.Status.NOT_IMPLEMENTED).type(type).entity(entity).build();
-        log.info("Returned chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
+        var response = Response.ok().type(type).entity(entity).build();
+        log.info("Created chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return response;
}

-    private ChatCompletionResponse newResponse(String model) {
-        return ChatCompletionResponse.builder()
-                .id(UUID.randomUUID().toString())
-                .created((int) (System.currentTimeMillis() / 1000))
-                .model(model)
-                .usage(Usage.builder()
-                        .totalTokens(secureRandom.nextInt())
-                        .promptTokens(secureRandom.nextInt())
-                        .completionTokens(secureRandom.nextInt())
-                        .completionTokensDetails(CompletionTokensDetails.builder()
-                                .reasoningTokens(secureRandom.nextInt())
-                                .build())
-                        .build())
-                .systemFingerprint(UUID.randomUUID().toString())
-                .build();
-    }
}
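
For reference, a client-side sketch of the two modes: with stream=false the endpoint returns a single JSON body, with stream=true it returns chunks delimited by the service's "\r\n" separator, one serialized ChatCompletionResponse (or ErrorMessage) per line. The base URL, the Comet-Workspace header value, and the model name are assumptions for a local setup, not a documented contract:

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class ChatCompletionsClientExample {
    public static void main(String[] args) throws Exception {
        var body = """
                {"model": "gpt-4o-mini", "stream": true,
                 "messages": [{"role": "user", "content": "Hello!"}]}
                """;
        var request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/v1/private/chat/completions")) // assumed local deployment
                .header("Content-Type", "application/json")
                .header("Comet-Workspace", "default") // hypothetical workspace name
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        // Each line is one chunk of the streamed completion.
        HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofLines())
                .body()
                .forEach(System.out::println);
    }
}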
apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java (new file)
@@ -0,0 +1,185 @@
package com.comet.opik.domain;

import com.comet.opik.api.LlmProvider;
import com.comet.opik.infrastructure.EncryptionUtils;
import com.comet.opik.infrastructure.LlmProviderClientConfig;
import com.comet.opik.utils.JsonUtils;
import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.langchain4j.internal.RetryUtils;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
import jakarta.inject.Singleton;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.ClientErrorException;
import jakarta.ws.rs.InternalServerErrorException;
import jakarta.ws.rs.ServerErrorException;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.glassfish.jersey.server.ChunkedOutput;
import ru.vyarus.dropwizard.guice.module.yaml.bind.Config;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Optional;

@Singleton
@Slf4j
public class ChatCompletionService {

private static final String UNEXPECTED_ERROR_CALLING_LLM_PROVIDER = "Unexpected error calling LLM provider";

private final LlmProviderClientConfig llmProviderClientConfig;
private final LlmProviderApiKeyService llmProviderApiKeyService;
private final RetryUtils.RetryPolicy retryPolicy;

@Inject
public ChatCompletionService(
@NonNull @Config LlmProviderClientConfig llmProviderClientConfig,
@NonNull LlmProviderApiKeyService llmProviderApiKeyService) {
this.llmProviderApiKeyService = llmProviderApiKeyService;
this.llmProviderClientConfig = llmProviderClientConfig;
this.retryPolicy = newRetryPolicy();
}

private RetryUtils.RetryPolicy newRetryPolicy() {
var retryPolicyBuilder = RetryUtils.retryPolicyBuilder();
Optional.ofNullable(llmProviderClientConfig.getMaxAttempts()).ifPresent(retryPolicyBuilder::maxAttempts);
Optional.ofNullable(llmProviderClientConfig.getJitterScale()).ifPresent(retryPolicyBuilder::jitterScale);
Optional.ofNullable(llmProviderClientConfig.getBackoffExp()).ifPresent(retryPolicyBuilder::backoffExp);
return retryPolicyBuilder
.delayMillis(llmProviderClientConfig.getDelayMillis())
.build();
}

public ChatCompletionResponse create(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
log.info("Creating chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
var openAiClient = getAndConfigureOpenAiClient(request, workspaceId);
ChatCompletionResponse chatCompletionResponse;
try {
chatCompletionResponse = retryPolicy.withRetry(() -> openAiClient.chatCompletion(request).execute());
} catch (RuntimeException runtimeException) {
log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, runtimeException);
if (runtimeException.getCause() instanceof OpenAiHttpException openAiHttpException) {
if (openAiHttpException.code() >= 400 && openAiHttpException.code() <= 499) {
throw new ClientErrorException(openAiHttpException.getMessage(), openAiHttpException.code());
}
throw new ServerErrorException(openAiHttpException.getMessage(), openAiHttpException.code());
}
throw new InternalServerErrorException(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
}
log.info("Created chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chatCompletionResponse;
}

public ChunkedOutput<String> createAndStreamResponse(
@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
log.info("Creating and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
var openAiClient = getAndConfigureOpenAiClient(request, workspaceId);
var chunkedOutput = new ChunkedOutput<String>(String.class, "\r\n");
openAiClient.chatCompletion(request)
.onPartialResponse(chatCompletionResponse -> send(chatCompletionResponse, chunkedOutput))
.onComplete(() -> close(chunkedOutput))
.onError(throwable -> handle(throwable, chunkedOutput))
.execute();
log.info("Created and streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return chunkedOutput;
}

private OpenAiClient getAndConfigureOpenAiClient(ChatCompletionRequest request, String workspaceId) {
var llmProvider = getLlmProvider(request.model());
var encryptedApiKey = getEncryptedApiKey(workspaceId, llmProvider);
return newOpenAiClient(encryptedApiKey);
}

/**
 * The agreed requirement is to resolve the LLM provider and its API key based on the model.
 * Currently, only OpenAI is supported, so the model param is ignored.
 * No further validation is needed on the model, as it's simply forwarded in the OpenAI request and will be
 * rejected if not valid.
 */
private LlmProvider getLlmProvider(String model) {
return LlmProvider.OPEN_AI;
}

/**
 * Finding API keys isn't paginated at the moment, since only OpenAI is supported.
 * Even in the future, the number of supported LLM providers per workspace is expected to remain very low.
 */
private String getEncryptedApiKey(String workspaceId, LlmProvider llmProvider) {
return llmProviderApiKeyService.find(workspaceId).content().stream()
.filter(providerApiKey -> llmProvider.equals(providerApiKey.provider()))
.findFirst()
.orElseThrow(() -> new BadRequestException("API key not configured for LLM provider '%s'".formatted(
llmProvider.getValue())))
.apiKey();
}

/**
 * Initially, only OpenAI is supported, so there's no need for more sophisticated client resolution to start with.
 * At the moment, neither the openai4j client nor the langchain4j wrappers support dynamic API keys, which can
 * imply a significant performance penalty in later phases. The following options should be evaluated:
 * - Cache clients, though this can be unsafe.
 * - Find and evaluate other clients.
 * - Implement our own client.
 * TODO as part of: <a href="https://comet-ml.atlassian.net/browse/OPIK-522">OPIK-522</a>
 */
private OpenAiClient newOpenAiClient(String encryptedApiKey) {
var openAiClientBuilder = OpenAiClient.builder();
Optional.ofNullable(llmProviderClientConfig.getOpenApiClient())
.map(LlmProviderClientConfig.OpenApiClientConfig::url)
.ifPresent(baseUrl -> {
if (StringUtils.isNotBlank(baseUrl)) {
openAiClientBuilder.baseUrl(baseUrl);
}
});
Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
.ifPresent(callTimeout -> openAiClientBuilder.callTimeout(callTimeout.toJavaDuration()));
Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
.ifPresent(connectTimeout -> openAiClientBuilder.connectTimeout(connectTimeout.toJavaDuration()));
Optional.ofNullable(llmProviderClientConfig.getReadTimeout())
.ifPresent(readTimeout -> openAiClientBuilder.readTimeout(readTimeout.toJavaDuration()));
Optional.ofNullable(llmProviderClientConfig.getWriteTimeout())
.ifPresent(writeTimeout -> openAiClientBuilder.writeTimeout(writeTimeout.toJavaDuration()));
return openAiClientBuilder
.openAiApiKey(EncryptionUtils.decrypt(encryptedApiKey))
.build();
}

private void send(Object item, ChunkedOutput<String> chunkedOutput) {
if (chunkedOutput.isClosed()) {
log.warn("Output stream is already closed");
return;
}
try {
chunkedOutput.write(JsonUtils.writeValueAsString(item));
} catch (IOException ioException) {
throw new UncheckedIOException(ioException);
}
}

private void handle(Throwable throwable, ChunkedOutput<String> chunkedOutput) {
log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, throwable);
var errorMessage = new ErrorMessage(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER);
if (throwable instanceof OpenAiHttpException openAiHttpException) {
errorMessage = new ErrorMessage(openAiHttpException.code(), openAiHttpException.getMessage());
}
try {
send(errorMessage, chunkedOutput);
} catch (UncheckedIOException uncheckedIOException) {
log.error("Failed to stream error message to client", uncheckedIOException);
}
close(chunkedOutput);
}

private void close(ChunkedOutput<String> chunkedOutput) {
try {
chunkedOutput.close();
} catch (IOException ioException) {
log.error("Failed to close output stream", ioException);
}
}
}
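
To make the first option in the client-resolution Javadoc concrete, a hypothetical sketch of caching clients keyed by the encrypted API key; as the comment warns, this is unsafe without eviction on key rotation, and it is not part of this PR:

import com.comet.opik.infrastructure.EncryptionUtils;
import dev.ai4j.openai4j.OpenAiClient;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

class CachingOpenAiClientFactory {

    private final Map<String, OpenAiClient> cache = new ConcurrentHashMap<>();

    OpenAiClient get(String encryptedApiKey) {
        // Reuses one client per key instead of rebuilding the HTTP stack per request.
        return cache.computeIfAbsent(encryptedApiKey,
                key -> OpenAiClient.builder()
                        .openAiApiKey(EncryptionUtils.decrypt(key))
                        .build());
    }

    void evict(String encryptedApiKey) {
        // Must be called on key update/delete, or rotated keys keep serving stale clients.
        var client = cache.remove(encryptedApiKey);
        if (client != null) {
            client.shutdown(); // assumed openai4j cleanup hook; releases pooled connections
        }
    }
}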