From dceb931c330fa10489f4e69efd0a4d56619e8e1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Cre=C3=A3o?= Date: Tue, 14 Jan 2025 16:25:04 +0000 Subject: [PATCH] [OPIK-594] Triggering LLM calls to score after Traces are received (#1038) * Triggering LLM calls to score after Traces are received * Adapting LLM provider to the new format. Fixing missing traceId in the score when we moved into the batch storing. * PR change requests * running spotless * removing ununsed imports * Update apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java Co-authored-by: Thiago dos Santos Hora * Update apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java Co-authored-by: Thiago dos Santos Hora * applying PR requests --------- Co-authored-by: Thiago dos Santos Hora --- .../opik/api/FeedbackScoreBatchItem.java | 1 + .../java/com/comet/opik/api/ScoreSource.java | 3 +- .../v1/events/LlmAsJudgeMessageRender.java | 136 --------- .../v1/events/OnlineScoringEngine.java | 267 ++++++++++++++++++ .../v1/events/OnlineScoringEventListener.java | 50 ++-- .../opik/domain/ChatCompletionService.java | 23 ++ .../com/comet/opik/domain/TraceService.java | 5 +- .../LlmProviderClientGenerator.java | 24 ++ .../llmproviders/LlmProviderFactory.java | 12 + .../000009_extend_feedback_source_type.sql | 7 + .../events/LlmAsJudgeMessageRenderTest.java | 142 ---------- .../v1/events/OnlineScoringEngineTest.java | 241 ++++++++++++++++ .../OnlineScoringEventListenerTest.java | 139 --------- 13 files changed, 611 insertions(+), 439 deletions(-) delete mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/LlmAsJudgeMessageRender.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java create mode 100644 apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000009_extend_feedback_source_type.sql delete mode 100644 apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/LlmAsJudgeMessageRenderTest.java create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java delete mode 100644 apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEventListenerTest.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatchItem.java b/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatchItem.java index 5f89716009..e2718b649e 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatchItem.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/FeedbackScoreBatchItem.java @@ -23,6 +23,7 @@ @JsonIgnoreProperties(ignoreUnknown = true) @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) public record FeedbackScoreBatchItem( + // entity (trace or span) id @NotNull UUID id, @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Schema(description = "If null, the default project is used") String projectName, @JsonIgnore UUID projectId, diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/ScoreSource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/ScoreSource.java index 2c47b722b4..84b3482732 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/ScoreSource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/ScoreSource.java @@ -10,7 +10,8 @@ @RequiredArgsConstructor public enum ScoreSource { UI("ui"), - SDK("sdk"); + SDK("sdk"), + ONLINE_SCORING("online_scoring"); @JsonValue private final String value; diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/LlmAsJudgeMessageRender.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/LlmAsJudgeMessageRender.java deleted file mode 100644 index 91c49c486d..0000000000 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/LlmAsJudgeMessageRender.java +++ /dev/null @@ -1,136 +0,0 @@ -package com.comet.opik.api.resources.v1.events; - -import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; -import com.comet.opik.api.Trace; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.jayway.jsonpath.JsonPath; -import dev.ai4j.openai4j.chat.Message; -import dev.ai4j.openai4j.chat.SystemMessage; -import dev.ai4j.openai4j.chat.UserMessage; -import lombok.Builder; -import lombok.experimental.UtilityClass; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.text.StringSubstitutor; - -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - -@UtilityClass -@Slf4j -class LlmAsJudgeMessageRender { - - /** - * Render the rule evaluator message template using the values from an actual trace. - * - * As the rule my consist in multiple messages, we check each one of them for variables to fill. - * Then we go through every variable template to replace them for the value from the trace. - * - * @param trace the trace with value to use to replace template variables - * @param evaluatorCode the evaluator - * @return a list of AI messages, with templates rendered - */ - public static List renderMessages(Trace trace, - AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode evaluatorCode) { - // prepare the map of replacements to use in all messages - var parsedVariables = variableMapping(evaluatorCode.variables()); - - // extract the actual value from the Trace - var replacements = parsedVariables.stream().map(mapper -> { - var traceSection = switch (mapper.traceSection) { - case INPUT -> trace.input(); - case OUTPUT -> trace.output(); - case METADATA -> trace.metadata(); - }; - - return mapper.toBuilder() - .valueToReplace(extractFromJson(traceSection, mapper.jsonPath())) - .build(); - }) - .filter(mapper -> mapper.valueToReplace() != null) - .collect( - Collectors.toMap(LlmAsJudgeMessageRender.MessageVariableMapping::variableName, - LlmAsJudgeMessageRender.MessageVariableMapping::valueToReplace)); - - // will convert all '{{key}}' into 'value' - // TODO: replace with Mustache Java to be in confirm with FE - var templateRenderer = new StringSubstitutor(replacements, "{{", "}}"); - - // render the message templates from evaluator rule - return evaluatorCode.messages().stream() - .map(templateMessage -> { - var renderedMessage = templateRenderer.replace(templateMessage.content()); - - return switch (templateMessage.role()) { - case USER -> UserMessage.from(renderedMessage); - case SYSTEM -> SystemMessage.from(renderedMessage); - default -> { - log.info("No mapping for message role type {}", templateMessage.role()); - yield null; - } - }; - }) - .filter(Objects::nonNull) - .toList(); - } - - /** - * Parse evaluator\'s variable mapper into an usable list of - * - * @param evaluatorVariables a map with variables and a path into a trace input/output/metadata to replace - * @return a parsed list of mappings, easier to use for the template rendering - */ - public static List variableMapping(Map evaluatorVariables) { - return evaluatorVariables.entrySet().stream() - .map(mapper -> { - var templateVariable = mapper.getKey(); - var tracePath = mapper.getValue(); - - var builder = MessageVariableMapping.builder().variableName(templateVariable); - - if (tracePath.startsWith("input.")) { - builder.traceSection(TraceSection.INPUT) - .jsonPath("$" + tracePath.substring("input".length())); - } else if (tracePath.startsWith("output.")) { - builder.traceSection(TraceSection.OUTPUT) - .jsonPath("$" + tracePath.substring("output".length())); - } else if (tracePath.startsWith("metadata.")) { - builder.traceSection(TraceSection.METADATA) - .jsonPath("$" + tracePath.substring("metadata".length())); - } else { - log.info("Couldn't map trace path '{}' into a input/output/metadata path", tracePath); - return null; - } - - return builder.build(); - }) - .filter(Objects::nonNull) - .toList(); - } - - final ObjectMapper objectMapper = new ObjectMapper(); - - String extractFromJson(JsonNode json, String path) { - try { - // JsonPath didnt work with JsonNode, even explicitly using JacksonJsonProvider, so we convert to a Map - var forcedObject = objectMapper.convertValue(json, Map.class); - return JsonPath.parse(forcedObject).read(path); - } catch (Exception e) { - log.debug("Couldn't find path '{}' inside json {}: {}", path, json, e.getMessage()); - return null; - } - } - - public enum TraceSection { - INPUT, - OUTPUT, - METADATA - } - - @Builder(toBuilder = true) - public record MessageVariableMapping(TraceSection traceSection, String variableName, String jsonPath, - String valueToReplace) { - } -} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java new file mode 100644 index 0000000000..a69944a29d --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java @@ -0,0 +1,267 @@ +package com.comet.opik.api.resources.v1.events; + +import com.comet.opik.api.FeedbackScoreBatchItem; +import com.comet.opik.api.ScoreSource; +import com.comet.opik.api.Trace; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.jayway.jsonpath.JsonPath; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.request.ResponseFormat; +import dev.langchain4j.model.chat.request.ResponseFormatType; +import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonNumberSchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchema; +import dev.langchain4j.model.chat.request.json.JsonSchemaElement; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; +import dev.langchain4j.model.chat.response.ChatResponse; +import jakarta.validation.constraints.NotNull; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.experimental.UtilityClass; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.text.StringSubstitutor; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode; +import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeMessage; +import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeOutputSchema; + +@UtilityClass +@Slf4j +public class OnlineScoringEngine { + + static final String SCORE_FIELD_NAME = "score"; + static final String REASON_FIELD_NAME = "reason"; + static final String SCORE_FIELD_DESCRIPTION = "the score for "; + static final String REASON_FIELD_DESCRIPTION = "the reason for the score for "; + static final String DEFAULT_SCHEMA_NAME = "scoring_schema"; + + /** + * Prepare a request to a LLM-as-Judge evaluator (a ChatLanguageModel) rendering the template messages with + * Trace variables and with the proper structured output format. + * + * @param evaluatorCode the LLM-as-Judge 'code' + * @param trace the sampled Trace to be scored + * @return a request to trigger to any supported provider with a ChatLanguageModel + */ + public static ChatRequest prepareLlmRequest(@NotNull LlmAsJudgeCode evaluatorCode, + Trace trace) { + var responseFormat = toResponseFormat(evaluatorCode.schema()); + var renderedMessages = renderMessages(evaluatorCode.messages(), evaluatorCode.variables(), trace); + + return ChatRequest.builder() + .messages(renderedMessages) + .responseFormat(responseFormat) + .build(); + } + + /** + * Render the rule evaluator message template using the values from an actual trace. + *

+ * As the rule may consist in multiple messages, we check each one of them for variables to fill. + * Then we go through every variable template to replace them for the value from the trace. + * + * @param templateMessages a list of messages with variables to fill with a Trace value + * @param variablesMap a map of template variable to a path to a value into a Trace + * @param trace the trace with value to use to replace template variables + * @return a list of AI messages, with templates rendered + */ + static List renderMessages(List templateMessages, Map variablesMap, + Trace trace) { + // prepare the map of replacements to use in all messages + var parsedVariables = toVariableMapping(variablesMap); + + // extract the actual value from the Trace + var replacements = parsedVariables.stream().map(mapper -> { + var traceSection = switch (mapper.traceSection()) { + case INPUT -> trace.input(); + case OUTPUT -> trace.output(); + case METADATA -> trace.metadata(); + }; + + return mapper.toBuilder() + .valueToReplace(extractFromJson(traceSection, mapper.jsonPath())) + .build(); + }) + .filter(mapper -> mapper.valueToReplace() != null) + .collect( + Collectors.toMap(MessageVariableMapping::variableName, MessageVariableMapping::valueToReplace)); + + // will convert all '{{key}}' into 'value' + // TODO: replace with Mustache Java to be in confirm with FE + var templateRenderer = new StringSubstitutor(replacements, "{{", "}}"); + + // render the message templates from evaluator rule + return templateMessages.stream() + .map(templateMessage -> { + var renderedMessage = templateRenderer.replace(templateMessage.content()); + + return switch (templateMessage.role()) { + case USER -> UserMessage.from(renderedMessage); + case SYSTEM -> SystemMessage.from(renderedMessage); + + default -> { + log.info("No mapping for message role type {}", templateMessage.role()); + yield null; + } + }; + }) + .filter(Objects::nonNull) + .toList(); + } + + /** + * Parse evaluator's variable mapper into a usable list of mappings. + * + * @param evaluatorVariables a map with variables and a path into a trace input/output/metadata to replace + * @return a parsed list of mappings, easier to use for the template rendering + */ + static List toVariableMapping(Map evaluatorVariables) { + return evaluatorVariables.entrySet().stream() + .map(mapper -> { + var templateVariable = mapper.getKey(); + var tracePath = mapper.getValue(); + + var builder = MessageVariableMapping.builder().variableName(templateVariable); + + // check if its input/output/metadata variable and fix the json path + Arrays.stream(TraceSection.values()) + .filter(traceSection -> tracePath.startsWith(traceSection.prefix)) + .findFirst() + .ifPresent(traceSection -> builder.traceSection(traceSection) + .jsonPath("$." + tracePath.substring(traceSection.prefix.length()))); + + return builder.build(); + }) + .filter(Objects::nonNull) + .toList(); + } + + final ObjectMapper objectMapper = new ObjectMapper(); + + String extractFromJson(JsonNode json, String path) { + try { + // JsonPath didnt work with JsonNode, even explicitly using JacksonJsonProvider, so we convert to a Map + var forcedObject = objectMapper.convertValue(json, Map.class); + return JsonPath.parse(forcedObject).read(path); + } catch (Exception e) { + log.debug("Couldn't find path '{}' inside json {}: {}", path, json, e.getMessage()); + return null; + } + } + + static ResponseFormat toResponseFormat(@NotNull List schema) { + // convert into something like + // "${name}": { "score": { "type": "${type}" , "description": ${description}", "reason": { "type" : "string" }} + Map structuredFields = schema.stream() + .map(scoreDefinition -> Map.entry(scoreDefinition.name(), + JsonObjectSchema.builder() + .description(scoreDefinition.description()) + .required(SCORE_FIELD_NAME, REASON_FIELD_NAME) + .properties(Map.of( + SCORE_FIELD_NAME, switch (scoreDefinition.type()) { + case BOOLEAN -> JsonBooleanSchema.builder() + .description(SCORE_FIELD_DESCRIPTION + scoreDefinition.name()) + .build(); + case INTEGER -> JsonIntegerSchema.builder() + .description(SCORE_FIELD_DESCRIPTION + scoreDefinition.name()) + .build(); + case DOUBLE -> JsonNumberSchema.builder() + .description(SCORE_FIELD_DESCRIPTION + scoreDefinition.name()) + .build(); + }, + REASON_FIELD_NAME, + JsonStringSchema.builder() + .description(REASON_FIELD_DESCRIPTION + scoreDefinition.name()) + .build())) + .build() + + )) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + var allPropertyNames = structuredFields.keySet().stream().toList(); + + var schemaBuilder = JsonObjectSchema.builder().required(allPropertyNames).properties(structuredFields).build(); + + var jsonSchema = JsonSchema.builder().name(DEFAULT_SCHEMA_NAME).rootElement(schemaBuilder).build(); + + return ResponseFormat.builder() + .type(ResponseFormatType.JSON) + .jsonSchema(jsonSchema) + .build(); + } + + public static List toFeedbackScores(@NotNull ChatResponse chatResponse) { + var content = chatResponse.aiMessage().text(); + + JsonNode structuredResponse; + try { + structuredResponse = objectMapper.readTree(content); + if (!structuredResponse.isObject()) { + log.info("ChatResponse content returned into an empty JSON result"); + return Collections.emptyList(); + } + } catch (JsonProcessingException e) { + log.error("parsing LLM response into a JSON: {}", content, e); + return Collections.emptyList(); + } + + var spliterator = Spliterators.spliteratorUnknownSize(structuredResponse.fields(), + Spliterator.ORDERED | Spliterator.NONNULL); + + return StreamSupport.stream(spliterator, false) + .map(scoreMetric -> { + var scoreName = scoreMetric.getKey(); + var scoreNested = scoreMetric.getValue(); + + if (scoreNested == null || scoreNested.isMissingNode() || !scoreNested.has(SCORE_FIELD_NAME)) { + log.info("No score found for '{}' score in {}", scoreName, scoreNested); + return null; + } + + log.debug("new FeedbackScore[{}, {}, {}]", scoreName, + scoreNested.path(SCORE_FIELD_NAME).decimalValue(), + scoreNested.path(REASON_FIELD_NAME).asText()); + + return FeedbackScoreBatchItem.builder() + .name(scoreName) + .value(scoreNested.path(SCORE_FIELD_NAME).decimalValue()) + .reason(scoreNested.path(REASON_FIELD_NAME).asText()) + .source(ScoreSource.ONLINE_SCORING) + .build(); + }) + .filter(Objects::nonNull) + .toList(); + + } + + @AllArgsConstructor + public enum TraceSection { + INPUT("input."), + OUTPUT("output."), + METADATA("metadata."); + + final String prefix; + } + + @Builder(toBuilder = true) + public record MessageVariableMapping(TraceSection traceSection, String variableName, String jsonPath, + String valueToReplace) { + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEventListener.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEventListener.java index 0fa042b77c..5a15f71ea8 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEventListener.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEventListener.java @@ -7,9 +7,9 @@ import com.comet.opik.domain.AutomationRuleEvaluatorService; import com.comet.opik.domain.ChatCompletionService; import com.comet.opik.domain.FeedbackScoreService; +import com.comet.opik.infrastructure.auth.RequestContext; import com.google.common.eventbus.EventBus; import com.google.common.eventbus.Subscribe; -import dev.ai4j.openai4j.chat.ChatCompletionRequest; import jakarta.inject.Inject; import lombok.extern.slf4j.Slf4j; import ru.vyarus.dropwizard.guice.module.installer.feature.eager.EagerSingleton; @@ -20,6 +20,8 @@ import java.util.UUID; import java.util.stream.Collectors; +import static com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode; + @EagerSingleton @Slf4j public class OnlineScoringEventListener { @@ -41,7 +43,7 @@ public OnlineScoringEventListener(EventBus eventBus, /** * Listen for trace batches to check for existent Automation Rules to score them. - * + *
* Automation Rule registers the percentage of traces to score, how to score them and so on. * * @param tracesBatch a traces batch with workspaceId and userName @@ -57,23 +59,24 @@ public void onTracesCreated(TracesCreated tracesBatch) { .collect(Collectors.toMap(entry -> "projectId: " + entry.getKey(), entry -> entry.getValue().size())); - log.debug("[OnlineScoring] Received traces for workspace '{}': {}", tracesBatch.workspaceId(), countMap); + log.debug("Received traces for workspace '{}': {}", tracesBatch.workspaceId(), countMap); Random random = new Random(System.currentTimeMillis()); // fetch automation rules per project tracesByProject.forEach((projectId, traces) -> { - log.debug("[OnlineScoring] Fetching evaluators for {} traces, project '{}' on workspace '{}'", + log.debug("Fetching evaluators for {} traces, project '{}' on workspace '{}'", traces.size(), projectId, tracesBatch.workspaceId()); List evaluators = ruleEvaluatorService.findAll( projectId, tracesBatch.workspaceId(), AutomationRuleEvaluatorType.LLM_AS_JUDGE); - log.info("[OnlineScoring] Found {} evaluators for project '{}' on workspace '{}'", evaluators.size(), + log.info("Found {} evaluators for project '{}' on workspace '{}'", evaluators.size(), projectId, tracesBatch.workspaceId()); // for each rule, sample traces and score them evaluators.forEach(evaluator -> traces.stream() .filter(e -> random.nextFloat() < evaluator.getSamplingRate()) - .forEach(trace -> score(trace, tracesBatch.workspaceId(), evaluator))); + .forEach(trace -> score(trace, evaluator.getCode(), tracesBatch.workspaceId(), + tracesBatch.userName()))); }); } @@ -81,21 +84,32 @@ public void onTracesCreated(TracesCreated tracesBatch) { * Use AI Proxy to score the trace and store it as a FeedbackScore. * If the evaluator has multiple score definitions, it calls the LLM once per score definition. * - * @param trace the trace to score - * @param workspaceId the workspace the trace belongs - * @param evaluator the automation rule to score the trace + * @param trace the trace to score + * @param evaluatorCode the automation rule to score the trace + * @param workspaceId the workspace the trace belongs */ - private void score(Trace trace, String workspaceId, AutomationRuleEvaluatorLlmAsJudge evaluator) { - // TODO prepare base request - var baseRequestBuilder = ChatCompletionRequest.builder() - .model(evaluator.getCode().model().name()) - .temperature(evaluator.getCode().model().temperature()) - .messages(LlmAsJudgeMessageRender.renderMessages(trace, evaluator.getCode())) - .build(); + private void score(Trace trace, LlmAsJudgeCode evaluatorCode, String workspaceId, + String userName) { + + var scoreRequest = OnlineScoringEngine.prepareLlmRequest(evaluatorCode, trace); + + var chatResponse = aiProxyService.scoreTrace(scoreRequest, evaluatorCode.model(), workspaceId); + + var scores = OnlineScoringEngine.toFeedbackScores(chatResponse).stream() + .map(item -> item.toBuilder() + .id(trace.id()) + .projectId(trace.projectId()) + .projectName(trace.projectName()) + .build()) + .toList(); - // TODO: call AI Proxy and parse response into 1+ FeedbackScore + log.info("Received {} scores for traceId '{}' in workspace '{}'. Storing them.", scores.size(), trace.id(), + workspaceId); - // TODO: store FeedbackScores + feedbackScoreService.scoreBatchOfTraces(scores) + .contextWrite(ctx -> ctx.put(RequestContext.USER_NAME, userName) + .put(RequestContext.WORKSPACE_ID, workspaceId)) + .block(); } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java index a565581ffe..8bda68aacc 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/ChatCompletionService.java @@ -1,5 +1,6 @@ package com.comet.opik.domain; +import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; import com.comet.opik.domain.llmproviders.LlmProviderFactory; import com.comet.opik.domain.llmproviders.LlmProviderService; import com.comet.opik.infrastructure.LlmProviderClientConfig; @@ -7,6 +8,8 @@ import dev.ai4j.openai4j.chat.ChatCompletionRequest; import dev.ai4j.openai4j.chat.ChatCompletionResponse; import dev.langchain4j.internal.RetryUtils; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.response.ChatResponse; import io.dropwizard.jersey.errors.ErrorMessage; import jakarta.inject.Inject; import jakarta.inject.Singleton; @@ -88,6 +91,26 @@ public void createAndStreamResponse( request.model()); } + public ChatResponse scoreTrace(@NonNull ChatRequest chatRequest, + @NonNull AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters modelParameters, + @NonNull String workspaceId) { + var languageModelClient = llmProviderFactory.getLanguageModel(workspaceId, modelParameters); + + ChatResponse chatResponse; + try { + log.info("Initiating chat with model '{}' expecting structured response, workspaceId '{}'", + modelParameters.name(), workspaceId); + chatResponse = retryPolicy + .withRetry(() -> languageModelClient.chat(chatRequest)); + log.info("Completed chat with model '{}' expecting structured response, workspaceId '{}'", + modelParameters.name(), workspaceId); + return chatResponse; + } catch (RuntimeException runtimeException) { + log.error(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER, runtimeException); + throw new InternalServerErrorException(UNEXPECTED_ERROR_CALLING_LLM_PROVIDER); + } + } + private RetryUtils.RetryPolicy newRetryPolicy() { var retryPolicyBuilder = RetryUtils.retryPolicyBuilder(); Optional.ofNullable(llmProviderClientConfig.getMaxAttempts()).ifPresent(retryPolicyBuilder::maxAttempts); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java index 5c2b56cf34..06e02a4fad 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/TraceService.java @@ -107,8 +107,7 @@ public Mono create(@NonNull Trace trace) { new LockService.Lock(id, TRACE_KEY), Mono.defer(() -> insertTrace(trace, project, id))) .doOnSuccess(__ -> { - // forwards the trace with its actual projectId - var savedTrace = trace.toBuilder().projectId(project.id()).build(); + var savedTrace = trace.toBuilder().projectId(project.id()).projectName(projectName).build(); String workspaceId = ctx.get(RequestContext.WORKSPACE_ID); String userName = ctx.get(RequestContext.USER_NAME); @@ -157,7 +156,7 @@ private List bindTraceToProjectAndId(TraceBatch batch, List proj UUID id = trace.id() == null ? idGenerator.generateId() : trace.id(); IdGenerator.validateVersion(id, TRACE_KEY); - return trace.toBuilder().id(id).projectId(project.id()).build(); + return trace.toBuilder().id(id).projectId(project.id()).projectName(project.name()).build(); }) .toList(); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java index 145b597271..6671e534c1 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderClientGenerator.java @@ -1,11 +1,14 @@ package com.comet.opik.domain.llmproviders; +import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; import com.comet.opik.infrastructure.LlmProviderClientConfig; import dev.ai4j.openai4j.OpenAiClient; import dev.ai4j.openai4j.chat.ChatCompletionRequest; import dev.langchain4j.model.anthropic.internal.client.AnthropicClient; +import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.googleai.GoogleAiGeminiChatModel; import dev.langchain4j.model.googleai.GoogleAiGeminiStreamingChatModel; +import dev.langchain4j.model.openai.OpenAiChatModel; import lombok.NonNull; import lombok.RequiredArgsConstructor; import org.apache.commons.lang3.StringUtils; @@ -69,4 +72,25 @@ public GoogleAiGeminiStreamingChatModel newGeminiStreamingClient( return LlmProviderGeminiMapper.INSTANCE.toGeminiStreamingChatModel(apiKey, request, llmProviderClientConfig.getCallTimeout().toJavaDuration(), MAX_RETRIES); } + + public ChatLanguageModel newOpenAiChatLanguageModel(String apiKey, + AutomationRuleEvaluatorLlmAsJudge.@NonNull LlmAsJudgeModelParameters modelParameters) { + var builder = OpenAiChatModel.builder() + .modelName(modelParameters.name()) + .apiKey(apiKey) + .logRequests(true) + .logResponses(true); + + Optional.ofNullable(llmProviderClientConfig.getConnectTimeout()) + .ifPresent(connectTimeout -> builder.timeout(connectTimeout.toJavaDuration())); + + Optional.ofNullable(llmProviderClientConfig.getOpenAiClient()) + .map(LlmProviderClientConfig.OpenAiClientConfig::url) + .filter(StringUtils::isNotBlank) + .ifPresent(builder::baseUrl); + + Optional.ofNullable(modelParameters.temperature()).ifPresent(builder::temperature); + + return builder.build(); + } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java index 2a92e113ac..5e19cda445 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/llmproviders/LlmProviderFactory.java @@ -1,10 +1,12 @@ package com.comet.opik.domain.llmproviders; +import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; import com.comet.opik.api.LlmProvider; import com.comet.opik.domain.LlmProviderApiKeyService; import com.comet.opik.infrastructure.EncryptionUtils; import dev.ai4j.openai4j.chat.ChatCompletionModel; import dev.langchain4j.model.anthropic.AnthropicChatModelName; +import dev.langchain4j.model.chat.ChatLanguageModel; import jakarta.inject.Inject; import jakarta.inject.Singleton; import jakarta.ws.rs.BadRequestException; @@ -34,6 +36,16 @@ public LlmProviderService getService(@NonNull String workspaceId, @NonNull Strin }; } + public ChatLanguageModel getLanguageModel(@NonNull String workspaceId, + @NonNull AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters modelParameters) { + var llmProvider = getLlmProvider(modelParameters.name()); + var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider)); + + return switch (llmProvider) { + case LlmProvider.OPEN_AI -> llmProviderClientGenerator.newOpenAiChatLanguageModel(apiKey, modelParameters); + default -> throw new BadRequestException(String.format(ERROR_MODEL_NOT_SUPPORTED, modelParameters.name())); + }; + } /** * The agreed requirement is to resolve the LLM provider and its API key based on the model. */ diff --git a/apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000009_extend_feedback_source_type.sql b/apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000009_extend_feedback_source_type.sql new file mode 100644 index 0000000000..c08a7bb1bd --- /dev/null +++ b/apps/opik-backend/src/main/resources/liquibase/db-app-analytics/migrations/000009_extend_feedback_source_type.sql @@ -0,0 +1,7 @@ +--liquibase formatted sql +--changeset DanielAugusto:000009_extend_feedback_source_type + +ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.feedback_scores + MODIFY COLUMN `source` Enum8('sdk', 'ui', 'online_scoring'); + +--rollback ALTER TABLE ${ANALYTICS_DB_DATABASE_NAME}.feedback_scores MODIFY COLUMN `source` Enum8('sdk', 'ui'); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/LlmAsJudgeMessageRenderTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/LlmAsJudgeMessageRenderTest.java deleted file mode 100644 index 6afcd44fe7..0000000000 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/LlmAsJudgeMessageRenderTest.java +++ /dev/null @@ -1,142 +0,0 @@ -package com.comet.opik.api.resources.v1.events; - -import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; -import com.comet.opik.api.Trace; -import com.comet.opik.domain.AutomationRuleEvaluatorService; -import com.comet.opik.domain.ChatCompletionService; -import com.comet.opik.domain.FeedbackScoreService; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.common.eventbus.EventBus; -import dev.ai4j.openai4j.chat.Role; -import dev.ai4j.openai4j.chat.UserMessage; -import lombok.extern.slf4j.Slf4j; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestInstance; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; - -import static org.assertj.core.api.Assertions.assertThat; - -@Slf4j -@TestInstance(TestInstance.Lifecycle.PER_CLASS) -@DisplayName("LlmAsJudge Message Render") -public class LlmAsJudgeMessageRenderTest { - @Mock - AutomationRuleEvaluatorService ruleEvaluatorService; - @Mock - ChatCompletionService aiProxyService; - @Mock - FeedbackScoreService feedbackScoreService; - @Mock - EventBus eventBus; - OnlineScoringEventListener onlineScoringEventListener; - - AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode evaluatorCode; - Trace trace; - - String messageToTest = "Summary: {{summary}}\\nInstruction: {{instruction}}\\n\\n"; - String testEvaluator = """ - { - "model": { - "name": "gpt-4o", - "temperature": 0.3 - }, - "messages": [ - { - "role": "USER", - "content": "%s" - }, - { - "role": "SYSTEM", - "content": "You're a helpful AI, be cordial." - } - ], - "variables": { - "summary": "input.questions.question1", - "instruction": "output.output", - "nonUsed": "input.questions.question2", - "toFail1": "metadata.nonexistent.path" - }, - "schema": [ - { "name": "Relevance", "type": "INTEGER", "description": "Relevance of the summary" }, - { "name": "Conciseness", "type": "DOUBLE", "description": "Conciseness of the summary" }, - { "name": "Technical Accuracy", "type": "BOOLEAN", "description": "Technical accuracy of the summary" } - ] - } - """ - .formatted(messageToTest).trim(); - String summaryStr = "What was the approach to experimenting with different data mixtures?"; - String outputStr = "The study employed a systematic approach to experiment with varying data mixtures by manipulating the proportions and sources of datasets used for model training."; - String input = """ - { - "questions": { - "question1": "%s", - "question2": "Whatever, we wont use it anyway" - }, - "pdf_url": "https://arxiv.org/pdf/2406.04744", - "title": "CRAG -- Comprehensive RAG Benchmark" - } - """.formatted(summaryStr).trim(); - String output = """ - { - "output": "%s" - } - """.formatted(outputStr).trim(); - - @BeforeEach - void setUp() throws JsonProcessingException { - MockitoAnnotations.openMocks(this); - Mockito.doNothing().when(eventBus).register(Mockito.any()); - onlineScoringEventListener = new OnlineScoringEventListener(eventBus, ruleEvaluatorService, - aiProxyService, feedbackScoreService); - - var mapper = new ObjectMapper(); - evaluatorCode = mapper.readValue(testEvaluator, AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode.class); - trace = Trace.builder().input(mapper.readTree(input)).output(mapper.readTree(output)).build(); - } - - @Test - @DisplayName("parse variable mapping into a usable one") - void when__parseRuleVariables() { - var variableMappings = LlmAsJudgeMessageRender.variableMapping(evaluatorCode.variables()); - - assertThat(variableMappings).hasSize(4); - - var varSummary = variableMappings.get(0); - assertThat(varSummary.traceSection()).isEqualTo(LlmAsJudgeMessageRender.TraceSection.INPUT); - assertThat(varSummary.jsonPath()).isEqualTo("$.questions.question1"); - - var varInstruction = variableMappings.get(1); - assertThat(varInstruction.traceSection()).isEqualTo(LlmAsJudgeMessageRender.TraceSection.OUTPUT); - assertThat(varInstruction.jsonPath()).isEqualTo("$.output"); - - var varNonUsed = variableMappings.get(2); - assertThat(varNonUsed.traceSection()).isEqualTo(LlmAsJudgeMessageRender.TraceSection.INPUT); - assertThat(varNonUsed.jsonPath()).isEqualTo("$.questions.question2"); - - var varToFail = variableMappings.get(3); - assertThat(varToFail.traceSection()).isEqualTo(LlmAsJudgeMessageRender.TraceSection.METADATA); - assertThat(varToFail.jsonPath()).isEqualTo("$.nonexistent.path"); - } - - @Test - @DisplayName("render message templates with a trace") - void when__renderTemplate() { - var renderedMessages = LlmAsJudgeMessageRender.renderMessages(trace, evaluatorCode); - - assertThat(renderedMessages).hasSize(2); - - var userMessage = (UserMessage) renderedMessages.get(0); - assertThat(userMessage.role()).isEqualTo(Role.USER); - assertThat(userMessage.content().toString()).contains(summaryStr); - assertThat(userMessage.content().toString()).contains(outputStr); - - var systemMessage = renderedMessages.get(1); - assertThat(systemMessage.role()).isEqualTo(Role.SYSTEM); - } - -} diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java new file mode 100644 index 0000000000..56f4db3359 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java @@ -0,0 +1,241 @@ +package com.comet.opik.api.resources.v1.events; + +import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; +import com.comet.opik.api.LlmAsJudgeOutputSchemaType; +import com.comet.opik.api.ScoreSource; +import com.comet.opik.api.Trace; +import com.comet.opik.domain.AutomationRuleEvaluatorService; +import com.comet.opik.domain.ChatCompletionService; +import com.comet.opik.domain.FeedbackScoreService; +import com.comet.opik.podam.PodamFactoryUtils; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.eventbus.EventBus; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.SystemMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.model.chat.request.json.JsonBooleanSchema; +import dev.langchain4j.model.chat.request.json.JsonIntegerSchema; +import dev.langchain4j.model.chat.request.json.JsonNumberSchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.response.ChatResponse; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import uk.co.jemos.podam.api.PodamFactory; + +import java.math.BigDecimal; +import java.util.List; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +@Slf4j +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@DisplayName("LlmAsJudge Message Render") +public class OnlineScoringEngineTest { + @Mock + AutomationRuleEvaluatorService ruleEvaluatorService; + @Mock + ChatCompletionService aiProxyService; + @Mock + FeedbackScoreService feedbackScoreService; + @Mock + EventBus eventBus; + OnlineScoringEventListener onlineScoringEventListener; + + private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); + + AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode evaluatorCode; + Trace trace; + + String messageToTest = "Summary: {{summary}}\\nInstruction: {{instruction}}\\n\\n"; + String testEvaluator = """ + { + "model": { "name": "gpt-4o", "temperature": 0.3 }, + "messages": [ + { "role": "USER", "content": "%s" }, + { "role": "SYSTEM", "content": "You're a helpful AI, be cordial." } + ], + "variables": { + "summary": "input.questions.question1", + "instruction": "output.output", + "nonUsed": "input.questions.question2", + "toFail1": "metadata.nonexistent.path" + }, + "schema": [ + { "name": "Relevance", "type": "INTEGER", "description": "Relevance of the summary" }, + { "name": "Conciseness", "type": "DOUBLE", "description": "Conciseness of the summary" }, + { "name": "Technical Accuracy", "type": "BOOLEAN", "description": "Technical accuracy of the summary" } + ] + } + """ + .formatted(messageToTest).trim(); + String summaryStr = "What was the approach to experimenting with different data mixtures?"; + String outputStr = "The study employed a systematic approach to experiment with varying data mixtures by manipulating the proportions and sources of datasets used for model training."; + String input = """ + { + "questions": { + "question1": "%s", + "question2": "Whatever, we wont use it anyway" + }, + "pdf_url": "https://arxiv.org/pdf/2406.04744", + "title": "CRAG -- Comprehensive RAG Benchmark" + } + """.formatted(summaryStr).trim(); + String output = """ + { + "output": "%s" + } + """.formatted(outputStr).trim(); + + @BeforeEach + void setUp() throws JsonProcessingException { + MockitoAnnotations.openMocks(this); + Mockito.doNothing().when(eventBus).register(Mockito.any()); + onlineScoringEventListener = new OnlineScoringEventListener(eventBus, ruleEvaluatorService, + aiProxyService, feedbackScoreService); + + var mapper = new ObjectMapper(); + evaluatorCode = mapper.readValue(testEvaluator, AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode.class); + trace = Trace.builder().input(mapper.readTree(input)).output(mapper.readTree(output)).build(); + } + + @Test + @DisplayName("parse variable mapping into a usable one") + void testVariableMapping() { + var variableMappings = OnlineScoringEngine.toVariableMapping(evaluatorCode.variables()); + + assertThat(variableMappings).hasSize(4); + + var varSummary = variableMappings.get(0); + assertThat(varSummary.traceSection()).isEqualTo(OnlineScoringEngine.TraceSection.INPUT); + assertThat(varSummary.jsonPath()).isEqualTo("$.questions.question1"); + + var varInstruction = variableMappings.get(1); + assertThat(varInstruction.traceSection()).isEqualTo(OnlineScoringEngine.TraceSection.OUTPUT); + assertThat(varInstruction.jsonPath()).isEqualTo("$.output"); + + var varNonUsed = variableMappings.get(2); + assertThat(varNonUsed.traceSection()).isEqualTo(OnlineScoringEngine.TraceSection.INPUT); + assertThat(varNonUsed.jsonPath()).isEqualTo("$.questions.question2"); + + var varToFail = variableMappings.get(3); + assertThat(varToFail.traceSection()).isEqualTo(OnlineScoringEngine.TraceSection.METADATA); + assertThat(varToFail.jsonPath()).isEqualTo("$.nonexistent.path"); + } + + @Test + @DisplayName("render message templates with a trace") + void testRenderTemplate() { + var renderedMessages = OnlineScoringEngine.renderMessages(evaluatorCode.messages(), evaluatorCode.variables(), + trace); + + assertThat(renderedMessages).hasSize(2); + + var userMessage = renderedMessages.get(0); + assertThat(userMessage.getClass()).isEqualTo(UserMessage.class); + assertThat(((UserMessage) userMessage).singleText()).contains(summaryStr); + assertThat(((UserMessage) userMessage).singleText()).contains(outputStr); + + var systemMessage = renderedMessages.get(1); + assertThat(systemMessage.getClass()).isEqualTo(SystemMessage.class); + } + + @Test + @DisplayName("create a structured output response format given an Automation Rule Evaluator schema input") + void testToResponseFormat() { + // creates an entry for each possible output schema type + var inputIntSchema = factory.manufacturePojo(AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeOutputSchema.class) + .toBuilder().type(LlmAsJudgeOutputSchemaType.INTEGER).build(); + var inputBoolSchema = factory.manufacturePojo(AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeOutputSchema.class) + .toBuilder().type(LlmAsJudgeOutputSchemaType.BOOLEAN).build(); + var inputDoubleSchema = factory.manufacturePojo(AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeOutputSchema.class) + .toBuilder().type(LlmAsJudgeOutputSchemaType.DOUBLE).build(); + var schema = List.of(inputIntSchema, inputBoolSchema, inputDoubleSchema); + + var responseFormat = OnlineScoringEngine.toResponseFormat(schema); + + var schemaRoot = (JsonObjectSchema) responseFormat.jsonSchema().rootElement(); + assertThat(schemaRoot.properties()).hasSize(schema.size()); + assertThat(schemaRoot.required()).containsOnly(inputBoolSchema.name(), inputDoubleSchema.name(), + inputIntSchema.name()); + + var parsedIntSchema = (JsonObjectSchema) schemaRoot.properties().get(inputIntSchema.name()); + assertThat(parsedIntSchema.description()).isEqualTo(inputIntSchema.description()); + assertThat(parsedIntSchema.required()).containsOnly(OnlineScoringEngine.SCORE_FIELD_NAME, + OnlineScoringEngine.REASON_FIELD_NAME); + assertThat(parsedIntSchema.properties().get(OnlineScoringEngine.SCORE_FIELD_NAME).getClass()) + .isEqualTo(JsonIntegerSchema.class); + + var parsedBoolSchema = (JsonObjectSchema) schemaRoot.properties().get(inputBoolSchema.name()); + assertThat(parsedBoolSchema.description()).isEqualTo(inputBoolSchema.description()); + assertThat(parsedBoolSchema.required()).containsOnly(OnlineScoringEngine.SCORE_FIELD_NAME, + OnlineScoringEngine.REASON_FIELD_NAME); + assertThat(parsedBoolSchema.properties().get(OnlineScoringEngine.SCORE_FIELD_NAME).getClass()) + .isEqualTo(JsonBooleanSchema.class); + + var parsedDoubleSchema = (JsonObjectSchema) schemaRoot.properties().get(inputDoubleSchema.name()); + assertThat(parsedDoubleSchema.description()).isEqualTo(inputDoubleSchema.description()); + assertThat(parsedDoubleSchema.required()).containsOnly(OnlineScoringEngine.SCORE_FIELD_NAME, + OnlineScoringEngine.REASON_FIELD_NAME); + assertThat(parsedDoubleSchema.properties().get(OnlineScoringEngine.SCORE_FIELD_NAME).getClass()) + .isEqualTo(JsonNumberSchema.class); + + } + + private static Stream feedbackParsingArguments() { + var validAiMsgTxt = "{\"Relevance\":{\"score\":5,\"reason\":\"The summary directly addresses the approach taken in the study by mentioning the systematic experimentation with varying data mixtures and the manipulation of proportions and sources.\"}," + + + "\"Conciseness\":{\"score\":4,\"reason\":\"The summary is mostly concise but could be slightly more streamlined by removing redundant phrases.\"}," + + + "\"Technical Accuracy\":{\"score\":0,\"reason\":\"The summary accurately describes the experimental approach involving data mixtures, proportions, and sources, reflecting the technical details of the study.\"}}"; + var invalidAiMsgTxt = "a" + validAiMsgTxt; + + var validJson = arguments(validAiMsgTxt, 3); + var invalidJson = arguments(invalidAiMsgTxt, 0); + var emptyJson = arguments("", 0); + + return Stream.of(validJson, invalidJson, emptyJson); + } + + @ParameterizedTest + @MethodSource("feedbackParsingArguments") + @DisplayName("parse a OnlineScoring ChatResponse into Feedback Scores") + void testParseResponseIntoFeedbacks(String aiMessage, Integer expectedResults) { + var chatResponse = ChatResponse.builder().aiMessage(AiMessage.from(aiMessage)).build(); + var feedbackScores = OnlineScoringEngine.toFeedbackScores(chatResponse); + + assertThat(feedbackScores).hasSize(expectedResults); + + if (expectedResults > 0) { + var relevanceScore = feedbackScores.get(0); + assertThat(relevanceScore.name()).isEqualTo("Relevance"); + assertThat(relevanceScore.value()).isEqualTo(new BigDecimal(5)); + assertThat(relevanceScore.reason()).startsWith("The summary directly "); + assertThat(relevanceScore.source()).isEqualTo(ScoreSource.ONLINE_SCORING); + + var concisenessScore = feedbackScores.get(1); + assertThat(concisenessScore.name()).isEqualTo("Conciseness"); + assertThat(concisenessScore.value()).isEqualTo(new BigDecimal(4)); + assertThat(concisenessScore.reason()).startsWith("The summary is mostly "); + assertThat(concisenessScore.source()).isEqualTo(ScoreSource.ONLINE_SCORING); + + var techAccScore = feedbackScores.get(2); + assertThat(techAccScore.name()).isEqualTo("Technical Accuracy"); + assertThat(techAccScore.value()).isEqualTo(new BigDecimal(0)); + assertThat(techAccScore.reason()).startsWith("The summary accurately "); + assertThat(techAccScore.source()).isEqualTo(ScoreSource.ONLINE_SCORING); + + } + } +} diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEventListenerTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEventListenerTest.java deleted file mode 100644 index e9f120caef..0000000000 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEventListenerTest.java +++ /dev/null @@ -1,139 +0,0 @@ -package com.comet.opik.api.resources.v1.events; - -import com.comet.opik.api.AutomationRuleEvaluatorLlmAsJudge; -import com.comet.opik.api.Trace; -import com.comet.opik.api.resources.utils.AuthTestUtils; -import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; -import com.comet.opik.api.resources.utils.ClientSupportUtils; -import com.comet.opik.api.resources.utils.MigrationUtils; -import com.comet.opik.api.resources.utils.MySQLContainerUtils; -import com.comet.opik.api.resources.utils.RedisContainerUtils; -import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; -import com.comet.opik.api.resources.utils.WireMockUtils; -import com.comet.opik.api.resources.utils.resources.AutomationRuleEvaluatorResourceClient; -import com.comet.opik.api.resources.utils.resources.ProjectResourceClient; -import com.comet.opik.api.resources.utils.resources.TraceResourceClient; -import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; -import com.comet.opik.podam.PodamFactoryUtils; -import com.redis.testcontainers.RedisContainer; -import lombok.extern.slf4j.Slf4j; -import org.jdbi.v3.core.Jdbi; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Nested; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestInstance; -import org.junit.jupiter.api.extension.RegisterExtension; -import org.testcontainers.clickhouse.ClickHouseContainer; -import org.testcontainers.containers.MySQLContainer; -import org.testcontainers.lifecycle.Startables; -import ru.vyarus.dropwizard.guice.test.ClientSupport; -import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; -import uk.co.jemos.podam.api.PodamFactory; - -import java.util.UUID; - -import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; -import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; - -@Slf4j -@TestInstance(TestInstance.Lifecycle.PER_CLASS) -@DisplayName("Online Scoring Event Listener") -public class OnlineScoringEventListenerTest { - - private static final String API_KEY = UUID.randomUUID().toString(); - private static final String USER = UUID.randomUUID().toString(); - private static final String WORKSPACE_ID = UUID.randomUUID().toString(); - private static final String WORKSPACE_NAME = "workspace-" + UUID.randomUUID(); - - private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer(); - - private static final MySQLContainer MYSQL = MySQLContainerUtils.newMySQLContainer(); - - private static final ClickHouseContainer CLICKHOUSE = ClickHouseContainerUtils.newClickHouseContainer(); - - @RegisterExtension - private static final TestDropwizardAppExtension app; - - private static final WireMockUtils.WireMockRuntime wireMock; - - static { - Startables.deepStart(MYSQL, CLICKHOUSE, REDIS).join(); - - wireMock = WireMockUtils.startWireMock(); - - DatabaseAnalyticsFactory databaseAnalyticsFactory = ClickHouseContainerUtils - .newDatabaseAnalyticsFactory(CLICKHOUSE, DATABASE_NAME); - - app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( - MYSQL.getJdbcUrl(), databaseAnalyticsFactory, wireMock.runtimeInfo(), REDIS.getRedisURI()); - } - - private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); - - private String baseURI; - private ClientSupport client; - private TraceResourceClient traceResourceClient; - private AutomationRuleEvaluatorResourceClient evaluatorResourceClient; - private ProjectResourceClient projectResourceClient; - - @BeforeAll - void setUpAll(ClientSupport client, Jdbi jdbi) throws Exception { - - MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters()); - - try (var connection = CLICKHOUSE.createConnection("")) { - MigrationUtils.runDbMigration(connection, CLICKHOUSE_CHANGELOG_FILE, - ClickHouseContainerUtils.migrationParameters()); - } - - this.baseURI = "http://localhost:%d".formatted(client.getPort()); - this.client = client; - - ClientSupportUtils.config(client); - - mockTargetWorkspace(API_KEY, WORKSPACE_NAME, WORKSPACE_ID); - - this.traceResourceClient = new TraceResourceClient(this.client, baseURI); - this.evaluatorResourceClient = new AutomationRuleEvaluatorResourceClient(this.client, baseURI); - this.projectResourceClient = new ProjectResourceClient(this.client, baseURI, factory); - } - - @AfterAll - void tearDownAll() { - wireMock.server().stop(); - } - - private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { - AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER); - } - - @Nested - @TestInstance(TestInstance.Lifecycle.PER_CLASS) - class TracesCreatedEvent { - - @Test - @DisplayName("when a new trace is created, OnlineScoring should see it within a event") - void when__newTracesIsCreated__onlineScoringShouldKnow() { - var projectName = factory.manufacturePojo(String.class); - var projectId = projectResourceClient.createProject(projectName, API_KEY, WORKSPACE_NAME); - - var evaluator = factory.manufacturePojo(AutomationRuleEvaluatorLlmAsJudge.class) - .toBuilder().projectId(null).build(); - - evaluatorResourceClient.createEvaluator(evaluator, projectId, WORKSPACE_NAME, API_KEY); - - var trace = factory.manufacturePojo(Trace.class).toBuilder() - .projectName(projectName) - .build(); - - UUID traceId = traceResourceClient.createTrace(trace, API_KEY, WORKSPACE_NAME); - - Trace returnTrace = traceResourceClient.getById(traceId, WORKSPACE_NAME, API_KEY); - - // TODO: run the actual test checking for if we have a FeedbackScore by the end. Prob mocking AI Proxy. - } - } - -}