Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPIK-751] Use Mustache for online scoring #1043

Merged
merged 10 commits into from
Jan 15, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.comet.opik.api.FeedbackScoreBatchItem;
import com.comet.opik.api.ScoreSource;
import com.comet.opik.api.Trace;
import com.comet.opik.utils.MustacheUtils;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -26,7 +27,6 @@
import lombok.Builder;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.text.StringSubstitutor;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -103,14 +103,11 @@ static List<ChatMessage> renderMessages(List<LlmAsJudgeMessage> templateMessages
.collect(
Collectors.toMap(MessageVariableMapping::variableName, MessageVariableMapping::valueToReplace));

// will convert all '{{key}}' into 'value'
// TODO: replace with Mustache Java to be in confirm with FE
var templateRenderer = new StringSubstitutor(replacements, "{{", "}}");

// render the message templates from evaluator rule
return templateMessages.stream()
.map(templateMessage -> {
var renderedMessage = templateRenderer.replace(templateMessage.content());
// will convert all '{{key}}' into 'value'
var renderedMessage = MustacheUtils.render(templateMessage.content(), replacements);

return switch (templateMessage.role()) {
case USER -> UserMessage.from(renderedMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import com.comet.opik.api.PromptVersion.PromptVersionPage;
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.utils.MustacheVariableExtractor;
import com.comet.opik.utils.MustacheUtils;
import com.google.inject.ImplementedBy;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
Expand Down Expand Up @@ -396,7 +396,7 @@ private Set<String> getVariables(String template) {
return null;
}

return MustacheVariableExtractor.extractVariables(template);
return MustacheUtils.extractVariables(template);
}

private EntityAlreadyExistsException newConflict(String alreadyExists) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.comet.opik.api.PromptVersion;
import com.comet.opik.utils.JsonUtils;
import com.comet.opik.utils.MustacheVariableExtractor;
import com.comet.opik.utils.MustacheUtils;
import com.fasterxml.jackson.databind.JsonNode;
import org.jdbi.v3.core.mapper.ColumnMapper;
import org.jdbi.v3.core.statement.StatementContext;
Expand Down Expand Up @@ -35,7 +35,7 @@ private PromptVersion mapObject(JsonNode jsonNode) {
.template(jsonNode.get("template").asText())
.metadata(jsonNode.get("metadata"))
.changeDescription(jsonNode.get("change_description").asText())
.variables(MustacheVariableExtractor.extractVariables(jsonNode.get("template").asText()))
.variables(MustacheUtils.extractVariables(jsonNode.get("template").asText()))
.createdAt(Instant.from(FORMATTER.parse(jsonNode.get("created_at").asText())))
.createdBy(jsonNode.get("created_by").asText())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
import com.github.mustachejava.codes.ValueCode;
import lombok.experimental.UtilityClass;

import java.io.IOException;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

@UtilityClass
public class MustacheVariableExtractor {
public class MustacheUtils {

public static final MustacheFactory MF = new DefaultMustacheFactory();

Expand All @@ -31,6 +36,18 @@ public static Set<String> extractVariables(String template) {
return variables;
}

public static String render(String template, Map<String, ?> context) {

Mustache mustache = MF.compile(new StringReader(template), "template");

try (Writer writer = mustache.execute(new StringWriter(), context)) {
writer.flush();
return writer.toString();
} catch (IOException e) {
throw new UncheckedIOException("Failed to render template", e);
}
}

private static void collectVariables(Code[] codes, Set<String> variables) {
for (Code code : codes) {
if (Objects.requireNonNull(code) instanceof ValueCode valueCode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
@Slf4j
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@DisplayName("LlmAsJudge Message Render")
public class OnlineScoringEngineTest {
class OnlineScoringEngineTest {

@Mock
AutomationRuleEvaluatorService ruleEvaluatorService;
@Mock
Expand Down
Loading