diff --git a/apps/opik-backend/pom.xml b/apps/opik-backend/pom.xml index 5877a63fa4..69591d788c 100644 --- a/apps/opik-backend/pom.xml +++ b/apps/opik-backend/pom.xml @@ -58,6 +58,13 @@ pom import + + dev.langchain4j + langchain4j-bom + 0.36.2 + pom + import + @@ -200,6 +207,10 @@ java-uuid-generator ${uuid.java.generator.version} + + dev.langchain4j + langchain4j-open-ai + diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResource.java new file mode 100644 index 0000000000..fd2b66d11d --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResource.java @@ -0,0 +1,94 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.codahale.metrics.annotation.Timed; +import com.comet.opik.domain.TextStreamer; +import com.comet.opik.infrastructure.auth.RequestContext; +import dev.ai4j.openai4j.chat.ChatCompletionRequest; +import dev.ai4j.openai4j.chat.ChatCompletionResponse; +import dev.ai4j.openai4j.shared.CompletionTokensDetails; +import dev.ai4j.openai4j.shared.Usage; +import io.dropwizard.jersey.errors.ErrorMessage; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.media.ArraySchema; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.parameters.RequestBody; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.validation.Valid; +import jakarta.validation.constraints.NotNull; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import reactor.core.publisher.Flux; + +import java.security.SecureRandom; +import java.util.UUID; + +@Path("/v1/private/chat/completions") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +@Timed +@Slf4j +@RequiredArgsConstructor(onConstructor_ = @Inject) +@Tag(name = "Chat Completions", description = "Chat Completions related resources") +public class ChatCompletionsResource { + + private final @NonNull Provider requestContextProvider; + private final @NonNull TextStreamer textStreamer; + private final @NonNull SecureRandom secureRandom; + + @POST + @Produces({MediaType.SERVER_SENT_EVENTS, MediaType.APPLICATION_JSON}) + @Operation(operationId = "getChatCompletions", summary = "Get chat completions", description = "Get chat completions", responses = { + @ApiResponse(responseCode = "501", description = "Chat completions response", content = { + @Content(mediaType = "text/event-stream", array = @ArraySchema(schema = @Schema(type = "object", anyOf = { + ChatCompletionResponse.class, + ErrorMessage.class}))), + @Content(mediaType = "application/json", schema = @Schema(implementation = ChatCompletionResponse.class))}), + }) + public Response get( + @RequestBody(content = @Content(schema = @Schema(implementation = ChatCompletionRequest.class))) @NotNull @Valid ChatCompletionRequest request) { + var workspaceId = requestContextProvider.get().getWorkspaceId(); + String type; + Object entity; + if (Boolean.TRUE.equals(request.stream())) { + log.info("Streaming chat completions, workspaceId '{}', model '{}'", workspaceId, request.model()); + type = MediaType.SERVER_SENT_EVENTS; + var flux = Flux.range(0, 10).map(i -> newResponse(request.model())); + entity = textStreamer.getOutputStream(flux); + } else { + log.info("Getting chat completions, workspaceId '{}', model '{}'", workspaceId, request.model()); + type = MediaType.APPLICATION_JSON; + entity = newResponse(request.model()); + } + var response = Response.status(Response.Status.NOT_IMPLEMENTED).type(type).entity(entity).build(); + log.info("Returned chat completions, workspaceId '{}', model '{}'", workspaceId, request.model()); + return response; + } + + private ChatCompletionResponse newResponse(String model) { + return ChatCompletionResponse.builder() + .id(UUID.randomUUID().toString()) + .created((int) (System.currentTimeMillis() / 1000)) + .model(model) + .usage(Usage.builder() + .totalTokens(secureRandom.nextInt()) + .promptTokens(secureRandom.nextInt()) + .completionTokens(secureRandom.nextInt()) + .completionTokensDetails(CompletionTokensDetails.builder() + .reasoningTokens(secureRandom.nextInt()) + .build()) + .build()) + .systemFingerprint(UUID.randomUUID().toString()) + .build(); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TextStreamer.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TextStreamer.java new file mode 100644 index 0000000000..a2151ef8dd --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TextStreamer.java @@ -0,0 +1,56 @@ +package com.comet.opik.domain; + +import com.comet.opik.utils.JsonUtils; +import io.dropwizard.jersey.errors.ErrorMessage; +import jakarta.inject.Singleton; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.glassfish.jersey.server.ChunkedOutput; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.concurrent.TimeoutException; + +@Singleton +@Slf4j +public class TextStreamer { + + public ChunkedOutput getOutputStream(@NonNull Flux flux) { + var outputStream = new ChunkedOutput(String.class, "\n"); + Schedulers.boundedElastic() + .schedule(() -> flux.doOnNext(item -> send(item, outputStream)) + .onErrorResume(throwable -> handleError(throwable, outputStream)) + .doFinally(signalType -> close(outputStream)) + .subscribe()); + return outputStream; + } + + private void send(Object item, ChunkedOutput outputStream) { + try { + outputStream.write(JsonUtils.writeValueAsString(item)); + } catch (IOException exception) { + throw new UncheckedIOException(exception); + } + } + + private Flux handleError(Throwable throwable, ChunkedOutput outputStream) { + if (throwable instanceof TimeoutException) { + try { + send(new ErrorMessage(500, "Streaming operation timed out"), outputStream); + } catch (UncheckedIOException uncheckedIOException) { + log.error("Failed to stream error message to client", uncheckedIOException); + } + } + return Flux.error(throwable); + } + + private void close(ChunkedOutput outputStream) { + try { + outputStream.close(); + } catch (IOException ioException) { + log.error("Error while closing output stream", ioException); + } + } +} diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/ChatCompletionsClient.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/ChatCompletionsClient.java new file mode 100644 index 0000000000..13ba0b8255 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/utils/resources/ChatCompletionsClient.java @@ -0,0 +1,84 @@ +package com.comet.opik.api.resources.utils.resources; + +import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.utils.JsonUtils; +import com.fasterxml.jackson.core.type.TypeReference; +import dev.ai4j.openai4j.chat.ChatCompletionRequest; +import dev.ai4j.openai4j.chat.ChatCompletionResponse; +import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.core.GenericType; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import lombok.Builder; +import org.apache.http.HttpStatus; +import org.glassfish.jersey.client.ChunkedInput; +import ru.vyarus.dropwizard.guice.test.ClientSupport; + +import java.util.ArrayList; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +public class ChatCompletionsClient { + + private static final String RESOURCE_PATH = "%s/v1/private/chat/completions"; + + private static final GenericType> CHUNKED_INPUT_STRING_GENERIC_TYPE = new GenericType<>() { + }; + + private static final TypeReference CHAT_COMPLETION_RESPONSE_TYPE_REFERENCE = new TypeReference<>() { + }; + + private final ClientSupport clientSupport; + private final String baseURI; + + public ChatCompletionsClient(ClientSupport clientSupport) { + this.clientSupport = clientSupport; + this.baseURI = "http://localhost:%d".formatted(clientSupport.getPort()); + } + + public ChatCompletionResponse get(String apiKey, String workspaceName, ChatCompletionRequest request) { + assertThat(request.stream()).isFalse(); + + try (var response = clientSupport.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(RequestContext.WORKSPACE_HEADER, workspaceName) + .post(Entity.json(request))) { + + assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NOT_IMPLEMENTED); + + return response.readEntity(ChatCompletionResponse.class); + } + } + + public List getStream(String apiKey, String workspaceName, ChatCompletionRequest request) { + assertThat(request.stream()).isTrue(); + + try (var response = clientSupport.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .accept(MediaType.SERVER_SENT_EVENTS) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(RequestContext.WORKSPACE_HEADER, workspaceName) + .post(Entity.json(request))) { + + assertThat(response.getStatusInfo().getStatusCode()).isEqualTo(HttpStatus.SC_NOT_IMPLEMENTED); + + return getStreamedItems(response); + } + } + + private List getStreamedItems(Response response) { + var items = new ArrayList(); + try (var inputStream = response.readEntity(CHUNKED_INPUT_STRING_GENERIC_TYPE)) { + inputStream.setParser(ChunkedInput.createParser("\n")); + String stringItem; + while ((stringItem = inputStream.read()) != null) { + items.add(JsonUtils.readValue(stringItem, CHAT_COMPLETION_RESPONSE_TYPE_REFERENCE)); + } + } + return items; + } +} diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java new file mode 100644 index 0000000000..14fe8817b8 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/ChatCompletionsResourceTest.java @@ -0,0 +1,112 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; +import com.comet.opik.api.resources.utils.ClientSupportUtils; +import com.comet.opik.api.resources.utils.MigrationUtils; +import com.comet.opik.api.resources.utils.MySQLContainerUtils; +import com.comet.opik.api.resources.utils.RedisContainerUtils; +import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.api.resources.utils.resources.ChatCompletionsClient; +import com.comet.opik.podam.PodamFactoryUtils; +import com.redis.testcontainers.RedisContainer; +import dev.ai4j.openai4j.chat.ChatCompletionRequest; +import lombok.extern.slf4j.Slf4j; +import org.jdbi.v3.core.Jdbi; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.testcontainers.clickhouse.ClickHouseContainer; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.lifecycle.Startables; +import org.testcontainers.shaded.org.apache.commons.lang3.RandomStringUtils; +import ru.vyarus.dropwizard.guice.test.ClientSupport; +import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; +import uk.co.jemos.podam.api.PodamFactory; + +import java.sql.SQLException; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@Slf4j +public class ChatCompletionsResourceTest { + + private static final String API_KEY = RandomStringUtils.randomAlphanumeric(25); + private static final String WORKSPACE_ID = UUID.randomUUID().toString(); + private static final String WORKSPACE_NAME = RandomStringUtils.randomAlphanumeric(20); + private static final String USER = RandomStringUtils.randomAlphanumeric(20); + + private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer(); + private static final MySQLContainer MY_SQL_CONTAINER = MySQLContainerUtils.newMySQLContainer(); + private static final ClickHouseContainer CLICK_HOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer(); + + private static final WireMockUtils.WireMockRuntime wireMock = WireMockUtils.startWireMock(); + + @RegisterExtension + private static final TestDropwizardAppExtension app; + + static { + Startables.deepStart(REDIS, MY_SQL_CONTAINER, CLICK_HOUSE_CONTAINER).join(); + + var databaseAnalyticsFactory = ClickHouseContainerUtils.newDatabaseAnalyticsFactory( + CLICK_HOUSE_CONTAINER, ClickHouseContainerUtils.DATABASE_NAME); + + app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( + MY_SQL_CONTAINER.getJdbcUrl(), databaseAnalyticsFactory, wireMock.runtimeInfo(), REDIS.getRedisURI()); + } + + private final PodamFactory podamFactory = PodamFactoryUtils.newPodamFactory(); + + private ChatCompletionsClient chatCompletionsClient; + + @BeforeAll + void setUpAll(ClientSupport clientSupport, Jdbi jdbi) throws SQLException { + MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters()); + + try (var connection = CLICK_HOUSE_CONTAINER.createConnection("")) { + MigrationUtils.runDbMigration( + connection, + MigrationUtils.CLICKHOUSE_CHANGELOG_FILE, + ClickHouseContainerUtils.migrationParameters() + ); + } + + ClientSupportUtils.config(clientSupport); + + mockTargetWorkspace(API_KEY, WORKSPACE_NAME, WORKSPACE_ID); + + this.chatCompletionsClient = new ChatCompletionsClient(clientSupport); + } + + private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { + AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER); + } + + @Nested + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class Get { + + @Test + void get() { + var request = podamFactory.manufacturePojo(ChatCompletionRequest.Builder.class).stream(false).build(); + + var response = chatCompletionsClient.get(API_KEY, WORKSPACE_NAME, request); + + assertThat(response).isNotNull(); + } + + @Test + void getStream() { + var request = podamFactory.manufacturePojo(ChatCompletionRequest.Builder.class).stream(true).build(); + + var response = chatCompletionsClient.getStream(API_KEY, WORKSPACE_NAME, request); + + assertThat(response).hasSize(10); + } + } +} diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java index 79a04f0549..5c9c7bc08e 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java @@ -4102,12 +4102,12 @@ void find__whenFilteringBySupportedFields__thenReturnMatchingRows(Filter filter) List filters = List.of(filter); - var expectedDatasetItems = filter.operator() != Operator.NOT_EQUAL ? - List.of(datasetItems.getFirst()) : - datasetItems.subList(1, datasetItems.size()).reversed(); - var expectedExperimentItems = filter.operator() != Operator.NOT_EQUAL ? - List.of(experimentItems.getFirst()) : - experimentItems.subList(1, experimentItems.size()).reversed(); + var expectedDatasetItems = filter.operator() != Operator.NOT_EQUAL + ? List.of(datasetItems.getFirst()) + : datasetItems.subList(1, datasetItems.size()).reversed(); + var expectedExperimentItems = filter.operator() != Operator.NOT_EQUAL + ? List.of(experimentItems.getFirst()) + : experimentItems.subList(1, experimentItems.size()).reversed(); var actualPage = assertDatasetExperimentPage(datasetId, experimentId, filters, apiKey, workspaceName, columns, expectedDatasetItems); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java index 29999f5f64..64d6d8db67 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/TracesResourceTest.java @@ -1032,8 +1032,8 @@ private Stream equalAndNotEqualFilters() { @ParameterizedTest @MethodSource("equalAndNotEqualFilters") void getByProjectName__whenFilterIdAndNameEqual__thenReturnTracesFiltered(Operator operator, - Function, List> getExpectedTraces, - Function, List> getUnexpectedTraces) { + Function, List> getExpectedTraces, + Function, List> getUnexpectedTraces) { var workspaceName = RandomStringUtils.randomAlphanumeric(10); var workspaceId = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -1065,7 +1065,8 @@ void getByProjectName__whenFilterIdAndNameEqual__thenReturnTracesFiltered(Operat .operator(operator) .value(traces.getFirst().name()) .build()); - getAndAssertPage(workspaceName, projectName, filters, traces, expectedTraces.reversed(), unexpectedTraces, apiKey); + getAndAssertPage(workspaceName, projectName, filters, traces, expectedTraces.reversed(), unexpectedTraces, + apiKey); } @Test @@ -1241,8 +1242,8 @@ void getByProjectName__whenFilterNameNotContains__thenReturnTracesFiltered() { @ParameterizedTest @MethodSource("equalAndNotEqualFilters") void getByProjectName__whenFilterStartTimeEqual__thenReturnTracesFiltered(Operator operator, - Function, List> getExpectedTraces, - Function, List> getUnexpectedTraces) { + Function, List> getExpectedTraces, + Function, List> getUnexpectedTraces) { var workspaceName = RandomStringUtils.randomAlphanumeric(10); var workspaceId = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -1268,7 +1269,8 @@ void getByProjectName__whenFilterStartTimeEqual__thenReturnTracesFiltered(Operat .operator(operator) .value(traces.getFirst().startTime().toString()) .build()); - getAndAssertPage(workspaceName, projectName, filters, traces, expectedTraces.reversed(), unexpectedTraces, apiKey); + getAndAssertPage(workspaceName, projectName, filters, traces, expectedTraces.reversed(), unexpectedTraces, + apiKey); } @Test @@ -1567,8 +1569,8 @@ void getByProjectName__whenFilterTotalEstimatedCostGreaterThen__thenReturnTraces @ParameterizedTest @MethodSource("equalAndNotEqualFilters") void getByProjectName__whenFilterTotalEstimatedCostEqual_NotEqual__thenReturnTracesFiltered(Operator operator, - Function, List> getUnexpectedTraces, - Function, List> getExpectedTraces) { + Function, List> getUnexpectedTraces, + Function, List> getExpectedTraces) { var workspaceName = RandomStringUtils.randomAlphanumeric(10); var workspaceId = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -1618,8 +1620,8 @@ void getByProjectName__whenFilterTotalEstimatedCostEqual_NotEqual__thenReturnTra @ParameterizedTest @MethodSource("equalAndNotEqualFilters") void getByProjectName__whenFilterMetadataEqualString__thenReturnTracesFiltered(Operator operator, - Function, List> getExpectedTraces, - Function, List> getUnexpectedTraces) { + Function, List> getExpectedTraces, + Function, List> getUnexpectedTraces) { var workspaceName = RandomStringUtils.randomAlphanumeric(10); var workspaceId = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -1652,7 +1654,8 @@ void getByProjectName__whenFilterMetadataEqualString__thenReturnTracesFiltered(O .key("$.model[0].version") .value("OPENAI, CHAT-GPT 4.0") .build()); - getAndAssertPage(workspaceName, projectName, filters, traces, expectedTraces.reversed(), unexpectedTraces, apiKey); + getAndAssertPage(workspaceName, projectName, filters, traces, expectedTraces.reversed(), unexpectedTraces, + apiKey); } @Test @@ -2516,8 +2519,8 @@ void getByProjectName__whenFilterUsageLessThanEqual__thenReturnTracesFiltered(St @ParameterizedTest @MethodSource void getByProjectName__whenFilterFeedbackScoresEqual__thenReturnTracesFiltered(Operator operator, - Function, List> getExpectedTraces, - Function, List> getUnexpectedTraces) { + Function, List> getExpectedTraces, + Function, List> getUnexpectedTraces) { var workspaceName = RandomStringUtils.randomAlphanumeric(10); var workspaceId = UUID.randomUUID().toString(); var apiKey = UUID.randomUUID().toString(); @@ -2561,7 +2564,8 @@ void getByProjectName__whenFilterFeedbackScoresEqual__thenReturnTracesFiltered(O .key(traces.getFirst().feedbackScores().get(2).name().toUpperCase()) .value(traces.getFirst().feedbackScores().get(2).value().toString()) .build()); - getAndAssertPage(workspaceName, projectName, filters, traces, expectedTraces.reversed(), unexpectedTraces, apiKey); + getAndAssertPage(workspaceName, projectName, filters, traces, expectedTraces.reversed(), unexpectedTraces, + apiKey); } private Stream getByProjectName__whenFilterFeedbackScoresEqual__thenReturnTracesFiltered() {