Skip to content

Commit

Permalink
[OPIK-751] Use mustache for online scoring
Browse files Browse the repository at this point in the history
  • Loading branch information
thiagohora committed Jan 14, 2025
1 parent 4490aaa commit c3a66b0
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.comet.opik.api.FeedbackScoreBatchItem;
import com.comet.opik.api.ScoreSource;
import com.comet.opik.api.Trace;
import com.comet.opik.utils.MustacheUtils;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -26,7 +27,6 @@
import lombok.Builder;
import lombok.experimental.UtilityClass;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.text.StringSubstitutor;

import java.util.Arrays;
import java.util.Collections;
Expand All @@ -46,11 +46,11 @@
@Slf4j
public class OnlineScoringEngine {

final String SCORE_FIELD_NAME = "score";
final String REASON_FIELD_NAME = "reason";
final String SCORE_FIELD_DESCRIPTION = "the score for ";
final String REASON_FIELD_DESCRIPTION = "the reason for the score for ";
final String DEFAULT_SCHEMA_NAME = "scoring_schema";
static final String SCORE_FIELD_NAME = "score";
static final String REASON_FIELD_NAME = "reason";
static final String SCORE_FIELD_DESCRIPTION = "the score for ";
static final String REASON_FIELD_DESCRIPTION = "the reason for the score for ";
static final String DEFAULT_SCHEMA_NAME = "scoring_schema";

/**
* Prepare a request to a LLM-as-Judge evaluator (a ChatLanguageModel) rendering the template messages with
Expand Down Expand Up @@ -103,14 +103,10 @@ static List<ChatMessage> renderMessages(List<LlmAsJudgeMessage> templateMessages
.collect(
Collectors.toMap(MessageVariableMapping::variableName, MessageVariableMapping::valueToReplace));

// will convert all '{{key}}' into 'value'
// TODO: replace with Mustache Java to be in confirm with FE
var templateRenderer = new StringSubstitutor(replacements, "{{", "}}");

// render the message templates from evaluator rule
return templateMessages.stream()
.map(templateMessage -> {
var renderedMessage = templateRenderer.replace(templateMessage.content());
var renderedMessage = MustacheUtils.render(templateMessage.content(), replacements);

return switch (templateMessage.role()) {
case USER -> UserMessage.from(renderedMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import com.comet.opik.api.PromptVersion.PromptVersionPage;
import com.comet.opik.api.error.EntityAlreadyExistsException;
import com.comet.opik.infrastructure.auth.RequestContext;
import com.comet.opik.utils.MustacheVariableExtractor;
import com.comet.opik.utils.MustacheUtils;
import com.google.inject.ImplementedBy;
import io.dropwizard.jersey.errors.ErrorMessage;
import jakarta.inject.Inject;
Expand Down Expand Up @@ -396,7 +396,7 @@ private Set<String> getVariables(String template) {
return null;
}

return MustacheVariableExtractor.extractVariables(template);
return MustacheUtils.extractVariables(template);
}

private EntityAlreadyExistsException newConflict(String alreadyExists) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.comet.opik.api.PromptVersion;
import com.comet.opik.utils.JsonUtils;
import com.comet.opik.utils.MustacheVariableExtractor;
import com.comet.opik.utils.MustacheUtils;
import com.fasterxml.jackson.databind.JsonNode;
import org.jdbi.v3.core.mapper.ColumnMapper;
import org.jdbi.v3.core.statement.StatementContext;
Expand Down Expand Up @@ -35,7 +35,7 @@ private PromptVersion mapObject(JsonNode jsonNode) {
.template(jsonNode.get("template").asText())
.metadata(jsonNode.get("metadata"))
.changeDescription(jsonNode.get("change_description").asText())
.variables(MustacheVariableExtractor.extractVariables(jsonNode.get("template").asText()))
.variables(MustacheUtils.extractVariables(jsonNode.get("template").asText()))
.createdAt(Instant.from(FORMATTER.parse(jsonNode.get("created_at").asText())))
.createdBy(jsonNode.get("created_by").asText())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
import com.github.mustachejava.codes.ValueCode;
import lombok.experimental.UtilityClass;

import java.io.IOException;
import java.io.StringReader;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

@UtilityClass
public class MustacheVariableExtractor {
public class MustacheUtils {

public static final MustacheFactory MF = new DefaultMustacheFactory();

Expand All @@ -31,6 +36,18 @@ public static Set<String> extractVariables(String template) {
return variables;
}

public static String render(String template, Map<String, ?> context) {

Mustache mustache = MF.compile(new StringReader(template), "template");

try (Writer writer = mustache.execute(new StringWriter(), context)){
writer.flush();
return writer.toString();
} catch (IOException e) {
throw new UncheckedIOException("Failed to render template", e);
}
}

private static void collectVariables(Code[] codes, Set<String> variables) {
for (Code code : codes) {
if (Objects.requireNonNull(code) instanceof ValueCode valueCode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
@Slf4j
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@DisplayName("LlmAsJudge Message Render")
public class OnlineScoringEngineTest {
class OnlineScoringEngineTest {
@Mock
AutomationRuleEvaluatorService ruleEvaluatorService;
@Mock
Expand Down

0 comments on commit c3a66b0

Please sign in to comment.