Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OPIK-545: Add chat completions resource #852

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions apps/opik-backend/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@
<type>pom</type>
<scope>import</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-bom</artifactId>
<version>0.36.2</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>

Expand Down Expand Up @@ -200,6 +207,10 @@
<artifactId>java-uuid-generator</artifactId>
<version>${uuid.java.generator.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
</dependency>

<!-- Test -->

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package com.comet.opik.api.resources.v1.priv;

import com.codahale.metrics.annotation.Timed;
import com.comet.opik.domain.TextStreamer;
import com.comet.opik.infrastructure.auth.RequestContext;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.shared.CompletionTokensDetails;
import dev.ai4j.openai4j.shared.Usage;
import io.dropwizard.jersey.errors.ErrorMessage;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.media.ArraySchema;
import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.inject.Inject;
import jakarta.inject.Provider;
import jakarta.validation.Valid;
import jakarta.validation.constraints.NotNull;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import reactor.core.publisher.Flux;

import java.security.SecureRandom;
import java.util.UUID;

@Path("/v1/private/chat/completions")
@Produces(MediaType.APPLICATION_JSON)
@Consumes(MediaType.APPLICATION_JSON)
@Timed
@Slf4j
@RequiredArgsConstructor(onConstructor_ = @Inject)
@Tag(name = "Chat Completions", description = "Chat Completions related resources")
public class ChatCompletionsResource {

private final @NonNull Provider<RequestContext> requestContextProvider;
private final @NonNull TextStreamer textStreamer;
private final @NonNull SecureRandom secureRandom;

@POST
@Produces({MediaType.SERVER_SENT_EVENTS, MediaType.APPLICATION_JSON})
@Operation(operationId = "getChatCompletions", summary = "Get chat completions", description = "Get chat completions", responses = {
@ApiResponse(responseCode = "501", description = "Chat completions response", content = {
@Content(mediaType = "text/event-stream", array = @ArraySchema(schema = @Schema(type = "object", anyOf = {
ChatCompletionResponse.class,
ErrorMessage.class}))),
@Content(mediaType = "application/json", schema = @Schema(implementation = ChatCompletionResponse.class))}),
})
public Response get(
@RequestBody(content = @Content(schema = @Schema(implementation = ChatCompletionRequest.class))) @NotNull @Valid ChatCompletionRequest request) {
var workspaceId = requestContextProvider.get().getWorkspaceId();
String type;
Object entity;
if (Boolean.TRUE.equals(request.stream())) {
log.info("Streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
type = MediaType.SERVER_SENT_EVENTS;
var flux = Flux.range(0, 10).map(i -> newResponse(request.model()));
entity = textStreamer.getOutputStream(flux);
} else {
log.info("Getting chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
type = MediaType.APPLICATION_JSON;
entity = newResponse(request.model());
}
var response = Response.status(Response.Status.NOT_IMPLEMENTED).type(type).entity(entity).build();
log.info("Returned chat completions, workspaceId '{}', model '{}'", workspaceId, request.model());
return response;
}

private ChatCompletionResponse newResponse(String model) {
return ChatCompletionResponse.builder()
.id(UUID.randomUUID().toString())
.created((int) (System.currentTimeMillis() / 1000))
.model(model)
.usage(Usage.builder()
.totalTokens(secureRandom.nextInt())
.promptTokens(secureRandom.nextInt())
.completionTokens(secureRandom.nextInt())
.completionTokensDetails(CompletionTokensDetails.builder()
.reasoningTokens(secureRandom.nextInt())
.build())
.build())
.systemFingerprint(UUID.randomUUID().toString())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.comet.opik.domain;

import com.comet.opik.utils.JsonUtils;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Singleton;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.glassfish.jersey.server.ChunkedOutput;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.concurrent.TimeoutException;

@Singleton
@Slf4j
public class TextStreamer {

public ChunkedOutput<String> getOutputStream(@NonNull Flux<?> flux) {
var outputStream = new ChunkedOutput<String>(String.class, "\n");
andrescrz marked this conversation as resolved.
Show resolved Hide resolved
Schedulers.boundedElastic()
.schedule(() -> flux.doOnNext(item -> send(item, outputStream))
.onErrorResume(throwable -> handleError(throwable, outputStream))
.doFinally(signalType -> close(outputStream))
.subscribe());
return outputStream;
}

private void send(Object item, ChunkedOutput<String> outputStream) {
try {
outputStream.write(JsonUtils.writeValueAsString(item));
} catch (IOException exception) {
throw new UncheckedIOException(exception);
}
}

private <T> Flux<T> handleError(Throwable throwable, ChunkedOutput<String> outputStream) {
if (throwable instanceof TimeoutException) {
try {
send(new ErrorMessage(500, "Streaming operation timed out"), outputStream);
} catch (UncheckedIOException uncheckedIOException) {
log.error("Failed to stream error message to client", uncheckedIOException);
}
}
return Flux.error(throwable);
}

private void close(ChunkedOutput<String> outputStream) {
try {
outputStream.close();
} catch (IOException ioException) {
log.error("Error while closing output stream", ioException);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package com.comet.opik.api.resources.utils.resources;

import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.core.type.TypeReference;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.GenericType;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
import lombok.Builder;
import org.apache.http.HttpStatus;
import org.glassfish.jersey.client.ChunkedInput;
import ru.vyarus.dropwizard.guice.test.ClientSupport;

import java.util.ArrayList;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;

public class ChatCompletionsClient {

private static final String RESOURCE_PATH = "%s/v1/private/chat/completions";

private static final GenericType<ChunkedInput<String>> CHUNKED_INPUT_STRING_GENERIC_TYPE = new GenericType<>() {
};

private static final TypeReference<ChatCompletionResponse> CHAT_COMPLETION_RESPONSE_TYPE_REFERENCE = new TypeReference<>() {
};

private final ClientSupport clientSupport;
private final String baseURI;

public ChatCompletionsClient(ClientSupport clientSupport) {
this.clientSupport = clientSupport;
this.baseURI = "http://localhost:%d".formatted(clientSupport.getPort());
}

public ChatCompletionResponse get(String apiKey, String workspaceName, ChatCompletionRequest request) {
assertThat(request.stream()).isFalse();

try (var response = clientSupport.target(RESOURCE_PATH.formatted(baseURI))
.request()
.accept(MediaType.APPLICATION_JSON_TYPE)
.header(HttpHeaders.AUTHORIZATION, apiKey)
.header(RequestContext.WORKSPACE_HEADER, workspaceName)
.post(Entity.json(request))) {

assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NOT_IMPLEMENTED);

return response.readEntity(ChatCompletionResponse.class);
}
}

public List<ChatCompletionResponse> getStream(String apiKey, String workspaceName, ChatCompletionRequest request) {
assertThat(request.stream()).isTrue();

try (var response = clientSupport.target(RESOURCE_PATH.formatted(baseURI))
.request()
.accept(MediaType.SERVER_SENT_EVENTS)
.header(HttpHeaders.AUTHORIZATION, apiKey)
.header(RequestContext.WORKSPACE_HEADER, workspaceName)
.post(Entity.json(request))) {

assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NOT_IMPLEMENTED);

return getStreamedItems(response);
}
}

private List<ChatCompletionResponse> getStreamedItems(Response response) {
var items = new ArrayList<ChatCompletionResponse>();
try (var inputStream = response.readEntity(CHUNKED_INPUT_STRING_GENERIC_TYPE)) {
inputStream.setParser(ChunkedInput.createParser("\n"));
String stringItem;
while ((stringItem = inputStream.read()) != null) {
items.add(JsonUtils.readValue(stringItem, CHAT_COMPLETION_RESPONSE_TYPE_REFERENCE));
}
}
return items;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package com.comet.opik.api.resources.v1.priv;

import com.comet.opik.api.resources.utils.AuthTestUtils;
import com.comet.opik.api.resources.utils.ClickHouseContainerUtils;
import com.comet.opik.api.resources.utils.ClientSupportUtils;
import com.comet.opik.api.resources.utils.MigrationUtils;
import com.comet.opik.api.resources.utils.MySQLContainerUtils;
import com.comet.opik.api.resources.utils.RedisContainerUtils;
import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils;
import com.comet.opik.api.resources.utils.WireMockUtils;
import com.comet.opik.api.resources.utils.resources.ChatCompletionsClient;
import com.comet.opik.podam.PodamFactoryUtils;
import com.redis.testcontainers.RedisContainer;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import lombok.extern.slf4j.Slf4j;
import org.jdbi.v3.core.Jdbi;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.testcontainers.clickhouse.ClickHouseContainer;
import org.testcontainers.containers.MySQLContainer;
import org.testcontainers.lifecycle.Startables;
import org.testcontainers.shaded.org.apache.commons.lang3.RandomStringUtils;
import ru.vyarus.dropwizard.guice.test.ClientSupport;
import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension;
import uk.co.jemos.podam.api.PodamFactory;

import java.sql.SQLException;
import java.util.UUID;

import static org.assertj.core.api.Assertions.assertThat;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Slf4j
public class ChatCompletionsResourceTest {

private static final String API_KEY = RandomStringUtils.randomAlphanumeric(25);
private static final String WORKSPACE_ID = UUID.randomUUID().toString();
private static final String WORKSPACE_NAME = RandomStringUtils.randomAlphanumeric(20);
private static final String USER = RandomStringUtils.randomAlphanumeric(20);

private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer();
private static final MySQLContainer<?> MY_SQL_CONTAINER = MySQLContainerUtils.newMySQLContainer();
private static final ClickHouseContainer CLICK_HOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer();

private static final WireMockUtils.WireMockRuntime wireMock = WireMockUtils.startWireMock();

@RegisterExtension
private static final TestDropwizardAppExtension app;

static {
Startables.deepStart(REDIS, MY_SQL_CONTAINER, CLICK_HOUSE_CONTAINER).join();

var databaseAnalyticsFactory = ClickHouseContainerUtils.newDatabaseAnalyticsFactory(
CLICK_HOUSE_CONTAINER, ClickHouseContainerUtils.DATABASE_NAME);

app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension(
MY_SQL_CONTAINER.getJdbcUrl(), databaseAnalyticsFactory, wireMock.runtimeInfo(), REDIS.getRedisURI());
}

private final PodamFactory podamFactory = PodamFactoryUtils.newPodamFactory();

private ChatCompletionsClient chatCompletionsClient;

@BeforeAll
void setUpAll(ClientSupport clientSupport, Jdbi jdbi) throws SQLException {
MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters());

try (var connection = CLICK_HOUSE_CONTAINER.createConnection("")) {
MigrationUtils.runDbMigration(
connection,
MigrationUtils.CLICKHOUSE_CHANGELOG_FILE,
ClickHouseContainerUtils.migrationParameters()
);
}

ClientSupportUtils.config(clientSupport);

mockTargetWorkspace(API_KEY, WORKSPACE_NAME, WORKSPACE_ID);

this.chatCompletionsClient = new ChatCompletionsClient(clientSupport);
}

private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) {
AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER);
}

@Nested
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class Get {

@Test
void get() {
var request = podamFactory.manufacturePojo(ChatCompletionRequest.Builder.class).stream(false).build();

var response = chatCompletionsClient.get(API_KEY, WORKSPACE_NAME, request);

assertThat(response).isNotNull();
}

@Test
void getStream() {
var request = podamFactory.manufacturePojo(ChatCompletionRequest.Builder.class).stream(true).build();

var response = chatCompletionsClient.getStream(API_KEY, WORKSPACE_NAME, request);

assertThat(response).hasSize(10);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4102,12 +4102,12 @@ void find__whenFilteringBySupportedFields__thenReturnMatchingRows(Filter filter)

List<Filter> filters = List.of(filter);

var expectedDatasetItems = filter.operator() != Operator.NOT_EQUAL ?
List.of(datasetItems.getFirst()) :
datasetItems.subList(1, datasetItems.size()).reversed();
var expectedExperimentItems = filter.operator() != Operator.NOT_EQUAL ?
List.of(experimentItems.getFirst()) :
experimentItems.subList(1, experimentItems.size()).reversed();
var expectedDatasetItems = filter.operator() != Operator.NOT_EQUAL
? List.of(datasetItems.getFirst())
: datasetItems.subList(1, datasetItems.size()).reversed();
var expectedExperimentItems = filter.operator() != Operator.NOT_EQUAL
? List.of(experimentItems.getFirst())
: experimentItems.subList(1, experimentItems.size()).reversed();

var actualPage = assertDatasetExperimentPage(datasetId, experimentId, filters, apiKey, workspaceName,
columns, expectedDatasetItems);
Expand Down
Loading
Loading