diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java index a69944a29d..98066ee150 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngine.java @@ -3,6 +3,7 @@ import com.comet.opik.api.FeedbackScoreBatchItem; import com.comet.opik.api.ScoreSource; import com.comet.opik.api.Trace; +import com.comet.opik.utils.MustacheUtils; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; @@ -26,7 +27,6 @@ import lombok.Builder; import lombok.experimental.UtilityClass; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.text.StringSubstitutor; import java.util.Arrays; import java.util.Collections; @@ -103,14 +103,11 @@ static List renderMessages(List templateMessages .collect( Collectors.toMap(MessageVariableMapping::variableName, MessageVariableMapping::valueToReplace)); - // will convert all '{{key}}' into 'value' - // TODO: replace with Mustache Java to be in confirm with FE - var templateRenderer = new StringSubstitutor(replacements, "{{", "}}"); - // render the message templates from evaluator rule return templateMessages.stream() .map(templateMessage -> { - var renderedMessage = templateRenderer.replace(templateMessage.content()); + // will convert all '{{key}}' into 'value' + var renderedMessage = MustacheUtils.render(templateMessage.content(), replacements); return switch (templateMessage.role()) { case USER -> UserMessage.from(renderedMessage); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java index bc6accd75c..47227c3d52 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -6,7 +6,7 @@ import com.comet.opik.api.PromptVersion.PromptVersionPage; import com.comet.opik.api.error.EntityAlreadyExistsException; import com.comet.opik.infrastructure.auth.RequestContext; -import com.comet.opik.utils.MustacheVariableExtractor; +import com.comet.opik.utils.MustacheUtils; import com.google.inject.ImplementedBy; import io.dropwizard.jersey.errors.ErrorMessage; import jakarta.inject.Inject; @@ -396,7 +396,7 @@ private Set getVariables(String template) { return null; } - return MustacheVariableExtractor.extractVariables(template); + return MustacheUtils.extractVariables(template); } private EntityAlreadyExistsException newConflict(String alreadyExists) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/PromptVersionColumnMapper.java b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/PromptVersionColumnMapper.java index e2eb33afc5..e27462769f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/PromptVersionColumnMapper.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/infrastructure/db/PromptVersionColumnMapper.java @@ -2,7 +2,7 @@ import com.comet.opik.api.PromptVersion; import com.comet.opik.utils.JsonUtils; -import com.comet.opik.utils.MustacheVariableExtractor; +import com.comet.opik.utils.MustacheUtils; import com.fasterxml.jackson.databind.JsonNode; import org.jdbi.v3.core.mapper.ColumnMapper; import org.jdbi.v3.core.statement.StatementContext; @@ -35,7 +35,7 @@ private PromptVersion mapObject(JsonNode jsonNode) { .template(jsonNode.get("template").asText()) .metadata(jsonNode.get("metadata")) .changeDescription(jsonNode.get("change_description").asText()) - .variables(MustacheVariableExtractor.extractVariables(jsonNode.get("template").asText())) + .variables(MustacheUtils.extractVariables(jsonNode.get("template").asText())) .createdAt(Instant.from(FORMATTER.parse(jsonNode.get("created_at").asText()))) .createdBy(jsonNode.get("created_by").asText()) .build(); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/utils/MustacheVariableExtractor.java b/apps/opik-backend/src/main/java/com/comet/opik/utils/MustacheUtils.java similarity index 70% rename from apps/opik-backend/src/main/java/com/comet/opik/utils/MustacheVariableExtractor.java rename to apps/opik-backend/src/main/java/com/comet/opik/utils/MustacheUtils.java index 798e2b3ac2..7544c9bfbb 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/utils/MustacheVariableExtractor.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/utils/MustacheUtils.java @@ -7,14 +7,19 @@ import com.github.mustachejava.codes.ValueCode; import lombok.experimental.UtilityClass; +import java.io.IOException; import java.io.StringReader; +import java.io.StringWriter; +import java.io.UncheckedIOException; +import java.io.Writer; import java.util.HashSet; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; @UtilityClass -public class MustacheVariableExtractor { +public class MustacheUtils { public static final MustacheFactory MF = new DefaultMustacheFactory(); @@ -31,6 +36,18 @@ public static Set extractVariables(String template) { return variables; } + public static String render(String template, Map context) { + + Mustache mustache = MF.compile(new StringReader(template), "template"); + + try (Writer writer = mustache.execute(new StringWriter(), context)) { + writer.flush(); + return writer.toString(); + } catch (IOException e) { + throw new UncheckedIOException("Failed to render template", e); + } + } + private static void collectVariables(Code[] codes, Set variables) { for (Code code : codes) { if (Objects.requireNonNull(code) instanceof ValueCode valueCode) { diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java index 56f4db3359..f43580c33b 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/OnlineScoringEngineTest.java @@ -24,12 +24,13 @@ import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.jupiter.MockitoExtension; import uk.co.jemos.podam.api.PodamFactory; import java.math.BigDecimal; @@ -42,7 +43,9 @@ @Slf4j @TestInstance(TestInstance.Lifecycle.PER_CLASS) @DisplayName("LlmAsJudge Message Render") -public class OnlineScoringEngineTest { +@ExtendWith(MockitoExtension.class) +class OnlineScoringEngineTest { + @Mock AutomationRuleEvaluatorService ruleEvaluatorService; @Mock @@ -98,14 +101,37 @@ public class OnlineScoringEngineTest { } """.formatted(outputStr).trim(); + String edgeCaseTemplate = "Summary: {{summary}}\\nInstruction: {{ instruction }}\\n\\n"; + String testEvaluatorEdgeCase = """ + { + "model": { "name": "gpt-4o", "temperature": 0.3 }, + "messages": [ + { "role": "USER", "content": "%s" }, + { "role": "SYSTEM", "content": "You're a helpful AI, be cordial." } + ], + "variables": { + "summary": "input.questions.question1", + "instruction": "output.output", + "nonUsed": "input.questions.question2", + "toFail1": "metadata.nonexistent.path" + }, + "schema": [ + { "name": "Relevance", "type": "INTEGER", "description": "Relevance of the summary" }, + { "name": "Conciseness", "type": "DOUBLE", "description": "Conciseness of the summary" }, + { "name": "Technical Accuracy", "type": "BOOLEAN", "description": "Technical accuracy of the summary" } + ] + } + """ + .formatted(edgeCaseTemplate).trim(); + + private ObjectMapper mapper = new ObjectMapper(); + @BeforeEach void setUp() throws JsonProcessingException { - MockitoAnnotations.openMocks(this); Mockito.doNothing().when(eventBus).register(Mockito.any()); onlineScoringEventListener = new OnlineScoringEventListener(eventBus, ruleEvaluatorService, aiProxyService, feedbackScoreService); - var mapper = new ObjectMapper(); evaluatorCode = mapper.readValue(testEvaluator, AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode.class); trace = Trace.builder().input(mapper.readTree(input)).output(mapper.readTree(output)).build(); } @@ -238,4 +264,25 @@ void testParseResponseIntoFeedbacks(String aiMessage, Integer expectedResults) { } } + + @Test + @DisplayName("render a message template with edge cases") + void testRenderEdgeCaseTemplate() throws JsonProcessingException { + + var evaluatorEdgeCase = mapper.readValue(testEvaluatorEdgeCase, + AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeCode.class); + + var renderedMessages = OnlineScoringEngine.renderMessages(evaluatorEdgeCase.messages(), + evaluatorEdgeCase.variables(), trace); + + assertThat(renderedMessages).hasSize(2); + + var userMessage = renderedMessages.get(0); + assertThat(userMessage.getClass()).isEqualTo(UserMessage.class); + assertThat(((UserMessage) userMessage).singleText()).contains(summaryStr); + assertThat(((UserMessage) userMessage).singleText()).contains(outputStr); + + var systemMessage = renderedMessages.get(1); + assertThat(systemMessage.getClass()).isEqualTo(SystemMessage.class); + } }