From b954ab226a7300c49ab286c23db1c1aaf1b9864d Mon Sep 17 00:00:00 2001 From: Gleb Sizov Date: Mon, 28 Oct 2024 16:02:02 +0100 Subject: [PATCH 01/12] wip to support generate indexing statement --- .../indexinglanguage/ScriptParserContext.java | 8 +- .../expressions/Expression.java | 7 +- .../expressions/GenerateExpression.java | 269 ++++++++++++++++++ .../expressions/ScriptExpression.java | 7 +- .../expressions/StatementExpression.java | 7 +- .../src/main/javacc/IndexingParser.jj | 18 ++ .../GeneratorScriptTestCase.java | 19 ++ .../GeneratorScriptTester.java | 63 ++++ .../ScriptParserTestCase.java | 5 +- .../parser/DefaultFieldNameTestCase.java | 8 +- .../com/yahoo/language/process/Generator.java | 40 +++ model-integration/README | 2 +- 12 files changed, 443 insertions(+), 10 deletions(-) create mode 100644 indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java create mode 100644 indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java create mode 100644 indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java create mode 100644 linguistics/src/main/java/com/yahoo/language/process/Generator.java diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java index 01c688af8e33..438eb8f15c43 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java @@ -3,6 +3,7 @@ import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.vespa.indexinglanguage.parser.CharStream; @@ -17,12 +18,14 @@ public class ScriptParserContext { private 
AnnotatorConfig annotatorConfig = new AnnotatorConfig(); private Linguistics linguistics; private final Map embedders; + private final Map generators; private String defaultFieldName = null; private CharStream inputStream = null; - public ScriptParserContext(Linguistics linguistics, Map embedders) { + public ScriptParserContext(Linguistics linguistics, Map embedders, Map generators) { this.linguistics = linguistics; this.embedders = embedders; + this.generators = generators; } public AnnotatorConfig getAnnotatorConfig() { @@ -46,6 +49,9 @@ public ScriptParserContext setLinguistics(Linguistics linguistics) { public Map getEmbedders() { return Collections.unmodifiableMap(embedders); } + + public Map getGenerators() { return Collections.unmodifiableMap(generators); + } public String getDefaultFieldName() { return defaultFieldName; diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java index 99d5e51c9eda..3975efd845ba 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java @@ -10,6 +10,7 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.*; import com.yahoo.vespa.indexinglanguage.parser.IndexingInput; @@ -273,7 +274,11 @@ public static Expression fromString(String expression) throws ParseException { } public static Expression fromString(String expression, Linguistics linguistics, Map embedders) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); + return newInstance(new 
ScriptParserContext(linguistics, embedders, Map.of()).setInputStream(new IndexingInput(expression))); + } + + public static Expression fromString(String expression, Linguistics linguistics, Map embedders, Map generators) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders, generators).setInputStream(new IndexingInput(expression))); } public static Expression newInstance(ScriptParserContext context) throws ParseException { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java new file mode 100644 index 000000000000..2d6c896b477c --- /dev/null +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java @@ -0,0 +1,269 @@ +package com.yahoo.vespa.indexinglanguage.expressions; + +import com.yahoo.document.ArrayDataType; +import com.yahoo.document.DataType; +import com.yahoo.document.DocumentType; +import com.yahoo.document.Field; +import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.Array; +import com.yahoo.document.datatypes.StringFieldValue; +import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.language.Linguistics; +import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +public class GenerateExpression extends Expression { + private final Linguistics linguistics; + private final Generator generator; + private final String generatorId; + private final List generatorArguments; + + /** The destination the embedding will be written to on the form [schema name].[field name] */ + private String destination; + + public GenerateExpression( + Linguistics linguistics, + Map 
generators, + String generatorId, + List generatorArguments + ) { + super(null); + this.linguistics = linguistics; + this.generatorId = generatorId; + this.generatorArguments = List.copyOf(generatorArguments); + + boolean generatorIdProvided = generatorId != null && !generatorId.isEmpty(); + + if (generators.isEmpty()) { + throw new IllegalStateException("No generators provided"); // should never happen + } + else if (generators.size() == 1 && ! generatorIdProvided) { + this.generator = generators.entrySet().stream().findFirst().get().getValue(); + } + else if (generators.size() > 1 && ! generatorIdProvided) { + this.generator = new Generator.FailingGenerator( + "Multiple generators are provided but no generator id is given. " + + "Valid generators are " + validGenerators(generators)); + } + else if ( ! embedders.containsKey(embedderId)) { + this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(embedders)); + } else { + this.embedder = embedders.get(embedderId); + } + } + + @Override + public DataType setInputType(DataType type, VerificationContext context) { + super.setInputType(type, context); + // TODO: Activate type checking + // if ( ! (type == DataType.STRING) + // && ! (type instanceof ArrayDataType array && array.getNestedType() == DataType.STRING)) + // throw new IllegalArgumentException("embed request either a string or array input type, but got " + type); + return null; // embed cannot determine the output type from the input + } + + @Override + public DataType setOutputType(DataType type, VerificationContext context) { + super.setOutputType(type, TensorDataType.any(), context); + return getInputType(context); // the input (string vs. array of string) cannot be determined from the output + } + + @Override + public void setStatementOutput(DocumentType documentType, Field field) { + targetType = toTargetTensor(field.getDataType()); + destination = documentType.getName() + "." 
+ field.getName(); + } + + @Override + protected void doVerify(VerificationContext context) { + targetType = toTargetTensor(getOutputType(context)); + if ( ! validTarget(targetType)) + throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor, a mapped 2d tensor, " + + "an array of dense 1d tensors, or a mixed 2d or 3d tensor"); + if (targetType.rank() == 2 && targetType.mappedSubtype().rank() == 2) { + if (embedderArguments.size() != 1) + throw new VerificationException(this, "When the embedding target field is a 2d mapped tensor " + + "the name of the tensor dimension that corresponds to the input array elements must " + + "be given as a second argument to embed, e.g: ... | embed splade paragraph | ..."); + if ( ! targetType.mappedSubtype().dimensionNames().contains(embedderArguments.get(0))) { + throw new VerificationException(this, "The dimension '" + embedderArguments.get(0) + "' given to embed " + + "is not a sparse dimension of the target type " + targetType); + + } + } + if (targetType.rank() == 3) { + if (embedderArguments.size() != 1) + throw new VerificationException(this, "When the embedding target field is a 3d tensor " + + "the name of the tensor dimension that corresponds to the input array elements must " + + "be given as a second argument to embed, e.g: ... | embed colbert paragraph | ..."); + if ( ! 
targetType.mappedSubtype().dimensionNames().contains(embedderArguments.get(0))) + throw new VerificationException(this, "The dimension '" + embedderArguments.get(0) + "' given to embed " + + "is not a sparse dimension of the target type " + targetType); + } + context.setCurrentType(createdOutputType()); + } + + @Override + protected void doExecute(ExecutionContext context) { + if (context.getCurrentValue() == null) return; + Tensor output; + if (context.getCurrentValue().getDataType() == DataType.STRING) { + output = embedSingleValue(context); + } + else if (context.getCurrentValue().getDataType() instanceof ArrayDataType arrayType + && arrayType.getNestedType() == DataType.STRING) { + output = embedArrayValue(context); + } + else { + throw new IllegalArgumentException("Embedding can only be done on string or string array fields, not " + + context.getCurrentValue().getDataType()); + } + context.setCurrentValue(new TensorFieldValue(output)); + } + + private Tensor embedSingleValue(ExecutionContext context) { + StringFieldValue input = (StringFieldValue)context.getCurrentValue(); + return embed(input.getString(), targetType, context); + } + + @SuppressWarnings("unchecked") + private Tensor embedArrayValue(ExecutionContext context) { + var input = (Array)context.getCurrentValue(); + var builder = Tensor.Builder.of(targetType); + if (targetType.rank() == 2) + if (targetType.indexedSubtype().rank() == 1) + embedArrayValueToRank2Tensor(input, builder, context); + else if(targetType.mappedSubtype().rank() == 2) + embedArrayValueToRank2MappedTensor(input, builder, context); + else + throw new IllegalArgumentException("Embedding an array into " + targetType + " is not supported"); + else + embedArrayValueToRank3Tensor(input, builder, context); + return builder.build(); + } + + private void embedArrayValueToRank2Tensor(Array input, + Tensor.Builder builder, + ExecutionContext context) { + String mappedDimension = targetType.mappedSubtype().dimensions().get(0).name(); + 
String indexedDimension = targetType.indexedSubtype().dimensions().get(0).name(); + for (int i = 0; i < input.size(); i++) { + Tensor tensor = embed(input.get(i).getString(), targetType.indexedSubtype(), context); + for (Iterator cells = tensor.cellIterator(); cells.hasNext(); ) { + Tensor.Cell cell = cells.next(); + builder.cell() + .label(mappedDimension, i) + .label(indexedDimension, cell.getKey().numericLabel(0)) + .value(cell.getValue()); + } + } + } + + private void embedArrayValueToRank3Tensor(Array input, + Tensor.Builder builder, + ExecutionContext context) { + String outerMappedDimension = embedderArguments.get(0); + String innerMappedDimension = targetType.mappedSubtype().dimensionNames().stream().filter(d -> !d.equals(outerMappedDimension)).findFirst().get(); + String indexedDimension = targetType.indexedSubtype().dimensions().get(0).name(); + long indexedDimensionSize = targetType.indexedSubtype().dimensions().get(0).size().get(); + var innerType = new TensorType.Builder(targetType.valueType()).mapped(innerMappedDimension).indexed(indexedDimension,indexedDimensionSize).build(); + int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension); + int indexedDimensionIndex = innerType.indexOfDimensionAsInt(indexedDimension); + for (int i = 0; i < input.size(); i++) { + Tensor tensor = embed(input.get(i).getString(), innerType, context); + for (Iterator cells = tensor.cellIterator(); cells.hasNext(); ) { + Tensor.Cell cell = cells.next(); + builder.cell() + .label(outerMappedDimension, i) + .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex)) + .label(indexedDimension, cell.getKey().numericLabel(indexedDimensionIndex)) + .value(cell.getValue()); + } + } + } + + private void embedArrayValueToRank2MappedTensor(Array input, + Tensor.Builder builder, + ExecutionContext context) { + String outerMappedDimension = embedderArguments.get(0); + String innerMappedDimension = 
targetType.mappedSubtype().dimensionNames().stream().filter(d -> !d.equals(outerMappedDimension)).findFirst().get(); + + var innerType = new TensorType.Builder(targetType.valueType()).mapped(innerMappedDimension).build(); + int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension); + + for (int i = 0; i < input.size(); i++) { + Tensor tensor = embed(input.get(i).getString(), innerType, context); + for (Iterator cells = tensor.cellIterator(); cells.hasNext(); ) { + Tensor.Cell cell = cells.next(); + builder.cell() + .label(outerMappedDimension, i) + .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex)) + .value(cell.getValue()); + } + } + } + + private Tensor embed(String input, TensorType targetType, ExecutionContext context) { + return embedder.embed(input, + new Embedder.Context(destination, context.getCache()).setLanguage(context.resolveLanguage(linguistics)) + .setEmbedderId(embedderId), + targetType); + } + + @Override + public DataType createdOutputType() { + return new TensorDataType(targetType); + } + + private static TensorType toTargetTensor(DataType dataType) { + if (dataType instanceof ArrayDataType) return toTargetTensor(dataType.getNestedType()); + if ( ! 
( dataType instanceof TensorDataType)) + throw new IllegalArgumentException("Expected a tensor data type but got " + dataType); + return ((TensorDataType)dataType).getTensorType(); + } + + private boolean validTarget(TensorType target) { + if (target.rank() == 1) // indexed or mapped 1d tensor + return true; + if (target.rank() == 2 && target.indexedSubtype().rank() == 1) + return true; // mixed 2d tensor + if(target.rank() == 2 && target.mappedSubtype().rank() == 2) + return true; // mapped 2d tensor + if (target.rank() == 3 && target.indexedSubtype().rank() == 1) + return true; // mixed 3d tensor + return false; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("embed"); + if (this.embedderId != null && !this.embedderId.isEmpty()) + sb.append(" ").append(this.embedderId); + embedderArguments.forEach(arg -> sb.append(" ").append(arg)); + return sb.toString(); + } + + @Override + public int hashCode() { return EmbedExpression.class.hashCode(); } + + @Override + public boolean equals(Object o) { + return o instanceof EmbedExpression; + } + + private static String validEmbedders(Map embedders) { + List embedderIds = new ArrayList<>(); + embedders.forEach((key, value) -> embedderIds.add(key)); + embedderIds.sort(null); + return String.join(", ", embedderIds); + } +} diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java index bf241cf61784..47031dc71bbd 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java @@ -5,6 +5,7 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import 
com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.ExpressionConverter; import com.yahoo.vespa.indexinglanguage.ScriptParser; @@ -116,7 +117,11 @@ public static ScriptExpression fromString(String expression) throws ParseExcepti } public static ScriptExpression fromString(String expression, Linguistics linguistics, Map embedders) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); + return newInstance(new ScriptParserContext(linguistics, embedders, Map.of()).setInputStream(new IndexingInput(expression))); + } + + public static Expression fromString(String expression, Linguistics linguistics, Map embedders, Map generators) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders, generators).setInputStream(new IndexingInput(expression))); } public static ScriptExpression newInstance(ScriptParserContext config) throws ParseException { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java index e8ed41df7b9c..d9eade7628a3 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java @@ -4,6 +4,7 @@ import com.yahoo.document.DataType; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.ExpressionConverter; import com.yahoo.vespa.indexinglanguage.ScriptParser; @@ -137,7 +138,11 @@ public static StatementExpression fromString(String expression) throws ParseExce } public static StatementExpression fromString(String expression, 
Linguistics linguistics, Map embedders) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); + return newInstance(new ScriptParserContext(linguistics, embedders, Map.of()).setInputStream(new IndexingInput(expression))); + } + + public static StatementExpression fromString(String expression, Linguistics linguistics, Map embedders, Map generators) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders, generators).setInputStream(new IndexingInput(expression))); } public static StatementExpression newInstance(ScriptParserContext config) throws ParseException { diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index fbf460ae6f6c..2d0a8c1da468 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -33,6 +33,7 @@ import com.yahoo.text.StringUtilities; import com.yahoo.vespa.indexinglanguage.expressions.*; import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.Linguistics; /** @@ -43,6 +44,7 @@ public class IndexingParser { private String defaultFieldName; private Linguistics linguistics; private Map embedders; + private Map generators; private AnnotatorConfig annotatorCfg; public IndexingParser(String str) { @@ -158,6 +160,7 @@ TOKEN : | | | + | | | | @@ -306,6 +309,7 @@ Expression value() : val = clearStateExp() | val = echoExp() | val = embedExp() | + val = generateExp() | val = exactExp() | val = executionValueExp() | val = flattenExp() | @@ -412,6 +416,20 @@ Expression embedExp() : { return new EmbedExpression(linguistics, embedders, embedderId, embedderArguments); } } +Expression generateExp() : +{ + String generatorId = ""; + String generatorArgument; + List generatorArguments = 
new ArrayList(); +} +{ + ( + [ LOOKAHEAD(2) generatorId = identifier() ] + ( LOOKAHEAD(2) generatorArgument = identifier() { generatorArguments.add(generatorArgument); } )* + ) + { return new GenerateExpression(linguistics, generators, generatorId, generatorArguments); } +} + Expression exactExp() : { int maxTokenLength = annotatorCfg.getMaxTokenLength(); diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java new file mode 100644 index 000000000000..ef2d39bdaa70 --- /dev/null +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java @@ -0,0 +1,19 @@ +package com.yahoo.vespa.indexinglanguage; + +import org.junit.Test; + +import java.util.Map; + +public class GeneratorScriptTestCase { + + @Test + public void testGenerate() { + // No embedders - parsing only + var tester = new GeneratorScriptTester( + Map.of("gen1", new GeneratorScriptTester.RepeatMockGenerator())); + tester.testStatement("input myText | generate | index", "hello", "hello hello"); + tester.testStatement("input myText | generate gen1 | index", "hello", "hello hello"); + tester.testStatement("input myText | generate 'gen1' | 'index'", "hello", "hello hello"); + tester.testStatement("input myText | generate 'gen1' | 'index'", null, null); + } +} \ No newline at end of file diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java new file mode 100644 index 000000000000..388f6ebe2853 --- /dev/null +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java @@ -0,0 +1,63 @@ +package com.yahoo.vespa.indexinglanguage; + +import com.yahoo.language.process.Generator; +import com.yahoo.language.simple.SimpleLinguistics; +import 
com.yahoo.vespa.indexinglanguage.expressions.Expression; +import com.yahoo.vespa.indexinglanguage.parser.ParseException; +import com.yahoo.document.DataType; +import com.yahoo.document.DocumentType; +import com.yahoo.document.Field; +import com.yahoo.document.datatypes.StringFieldValue; +import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class GeneratorScriptTester { + private final Map generators; + + public GeneratorScriptTester(Map generators) { + this.generators = generators; + } + + public static class RepeatMockGenerator implements Generator { + public String generate(String prompt) { + return prompt + " " + prompt; + } + } + + public void testStatement(String expressionString, String input, String expected) { + var expression = expressionFrom(expressionString); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myText", DataType.STRING)); + + if (input != null) + adapter.setValue("myText", new StringFieldValue(input)); + + expression.setStatementOutput(new DocumentType("myDocument"), new Field("myText", DataType.STRING)); + ExecutionContext context = new ExecutionContext(adapter); + expression.execute(context); + + if (input == null) { + assertFalse(adapter.values.containsKey("myText")); + } + else { + assertTrue(adapter.values.containsKey("myText")); + assertEquals(expected, ((StringFieldValue)adapter.values.get("myText")).getString()); + } + } + + private Expression expressionFrom(String string) { + try { + return Expression.fromString(string, new SimpleLinguistics(), Map.of(), generators); + } + catch (ParseException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java 
b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java index ac95c72a64bd..9f0e617dd40b 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.indexinglanguage; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.expressions.EchoExpression; import com.yahoo.vespa.indexinglanguage.expressions.InputExpression; @@ -96,7 +97,9 @@ private static void assertException(ParseException e, String expectedMessage) th } private static ScriptParserContext newContext(String input) { - return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse.asMap()).setInputStream(new IndexingInput(input)); + return new ScriptParserContext( + new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap() + ).setInputStream(new IndexingInput(input)); } } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java index 7a92d51fda39..e13db73f7d7d 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.indexinglanguage.parser; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.ScriptParserContext; import com.yahoo.vespa.indexinglanguage.expressions.Expression; @@ -18,10 +19,9 @@ public class DefaultFieldNameTestCase { @Test public 
void requireThatDefaultFieldNameIsAppliedWhenArgumentIsMissing() throws ParseException { IndexingInput input = new IndexingInput("input"); - InputExpression exp = (InputExpression)Expression.newInstance(new ScriptParserContext(new SimpleLinguistics(), - Embedder.throwsOnUse.asMap()) - .setInputStream(input) - .setDefaultFieldName("foo")); + InputExpression exp = (InputExpression) Expression.newInstance(new ScriptParserContext( + new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap() + ).setInputStream(input).setDefaultFieldName("foo")); assertEquals("foo", exp.getFieldName()); } diff --git a/linguistics/src/main/java/com/yahoo/language/process/Generator.java b/linguistics/src/main/java/com/yahoo/language/process/Generator.java new file mode 100644 index 000000000000..35ad791c3a7d --- /dev/null +++ b/linguistics/src/main/java/com/yahoo/language/process/Generator.java @@ -0,0 +1,40 @@ +package com.yahoo.language.process; + +import java.util.Map; + +public interface Generator { + + // Name of generator when none is explicitly given + String defaultGeneratorId = "default"; + + // An instance of this which throws IllegalStateException if attempted used + Generator throwsOnUse = new FailingGenerator(); + + // Returns this generator instance as a map with the default generator name + default Map asMap() { + return asMap(defaultGeneratorId); + } + + // Returns this generator instance as a map with the given name + default Map asMap(String name) { + return Map.of(name, this); + } + + String generate(String prompt); + + class FailingGenerator implements Generator { + private final String message; + + public FailingGenerator() { + this("No generator has been configured"); + } + + public FailingGenerator(String message) { + this.message = message; + } + + public String generate(String prompt) { + throw new IllegalStateException(message); + } + } +} diff --git a/model-integration/README b/model-integration/README index a58d88dc3112..ef9e367d13c4 
100644 --- a/model-integration/README +++ b/model-integration/README @@ -1,4 +1,4 @@ -3rd party ML models and converters from these to ranking expresssions, provided as a separate bundle. +3rd party ML models and converters from these to ranking expressions, provided as a separate bundle. This has two purposes - Make converters (importers) available to config models while loading them in just a single instance even when From 9461c66141c007b3ffd83bdb178d219c7d370376 Mon Sep 17 00:00:00 2001 From: Gleb Sizov Date: Tue, 29 Oct 2024 14:15:52 +0100 Subject: [PATCH 02/12] Initial support for generate in indexing language --- .../vespa/indexinglanguage/ScriptParser.java | 2 + .../expressions/GenerateExpression.java | 223 +++++------------- .../src/main/javacc/IndexingParser.jj | 1 + .../EmbeddingScriptTester.java | 28 +++ .../GeneratorScriptTestCase.java | 38 ++- .../GeneratorScriptTester.java | 59 ++++- linguistics/pom.xml | 6 + .../com/yahoo/language/process/Generator.java | 92 +++++++- 8 files changed, 261 insertions(+), 188 deletions(-) diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java index f243f854c299..db14014b4a83 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java @@ -47,6 +47,8 @@ private static T parse(ScriptParserContext context, Parse parser.setDefaultFieldName(context.getDefaultFieldName()); parser.setLinguistics(context.getLinguistcs()); parser.setEmbedders(context.getEmbedders()); + parser.setGenerators(context.getGenerators()); + try { return method.call(parser); } catch (ParseException e) { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java 
b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java index 2d6c896b477c..46b7d0f3c9ca 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java @@ -1,21 +1,13 @@ package com.yahoo.vespa.indexinglanguage.expressions; -import com.yahoo.document.ArrayDataType; import com.yahoo.document.DataType; import com.yahoo.document.DocumentType; import com.yahoo.document.Field; -import com.yahoo.document.TensorDataType; -import com.yahoo.document.datatypes.Array; import com.yahoo.document.datatypes.StringFieldValue; -import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.language.Linguistics; -import com.yahoo.language.process.Embedder; import com.yahoo.language.process.Generator; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; import java.util.Map; @@ -25,9 +17,12 @@ public class GenerateExpression extends Expression { private final String generatorId; private final List generatorArguments; - /** The destination the embedding will be written to on the form [schema name].[field name] */ + /** The destination the generated value will be written to in the form [schema name].[field name] */ private String destination; + /** The target type we are generating into. */ + private DataType targetType; + public GenerateExpression( Linguistics linguistics, Map generators, @@ -52,218 +47,114 @@ else if (generators.size() > 1 && ! generatorIdProvided) { "Multiple generators are provided but no generator id is given. " + "Valid generators are " + validGenerators(generators)); } - else if ( ! embedders.containsKey(embedderId)) { - this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. 
" + - "Valid embedders are " + validEmbedders(embedders)); + else if ( ! generators.containsKey(generatorId)) { + this.generator = new Generator.FailingGenerator("Can't find generator '" + generatorId + "'. " + + "Valid generators are " + validGenerators(generators)); } else { - this.embedder = embedders.get(embedderId); + this.generator = generators.get(generatorId); } } @Override public DataType setInputType(DataType type, VerificationContext context) { + // TODO: Not sure if this implementation of the methods is correct, needs careful review. super.setInputType(type, context); - // TODO: Activate type checking - // if ( ! (type == DataType.STRING) - // && ! (type instanceof ArrayDataType array && array.getNestedType() == DataType.STRING)) - // throw new IllegalArgumentException("embed request either a string or array input type, but got " + type); - return null; // embed cannot determine the output type from the input + + if (type != DataType.STRING) + throw new IllegalArgumentException("generate requires a string input type, but got " + type); + + return DataType.STRING; } @Override public DataType setOutputType(DataType type, VerificationContext context) { - super.setOutputType(type, TensorDataType.any(), context); - return getInputType(context); // the input (string vs. array of string) cannot be determined from the output + // TODO: Not sure if this implementation of the methods is correct, needs careful review. super.setOutputType(type, type, context); + + if (type != DataType.STRING) + throw new IllegalArgumentException("generate requires a string output type, but got " + type); + + return DataType.STRING; } @Override public void setStatementOutput(DocumentType documentType, Field field) { - targetType = toTargetTensor(field.getDataType()); + targetType = field.getDataType(); destination = documentType.getName() + "." 
+ field.getName(); } @Override protected void doVerify(VerificationContext context) { - targetType = toTargetTensor(getOutputType(context)); - if ( ! validTarget(targetType)) - throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor, a mapped 2d tensor, " + - "an array of dense 1d tensors, or a mixed 2d or 3d tensor"); - if (targetType.rank() == 2 && targetType.mappedSubtype().rank() == 2) { - if (embedderArguments.size() != 1) - throw new VerificationException(this, "When the embedding target field is a 2d mapped tensor " + - "the name of the tensor dimension that corresponds to the input array elements must " + - "be given as a second argument to embed, e.g: ... | embed splade paragraph | ..."); - if ( ! targetType.mappedSubtype().dimensionNames().contains(embedderArguments.get(0))) { - throw new VerificationException(this, "The dimension '" + embedderArguments.get(0) + "' given to embed " + - "is not a sparse dimension of the target type " + targetType); - - } - } - if (targetType.rank() == 3) { - if (embedderArguments.size() != 1) - throw new VerificationException(this, "When the embedding target field is a 3d tensor " + - "the name of the tensor dimension that corresponds to the input array elements must " + - "be given as a second argument to embed, e.g: ... | embed colbert paragraph | ..."); - if ( ! 
targetType.mappedSubtype().dimensionNames().contains(embedderArguments.get(0))) - throw new VerificationException(this, "The dimension '" + embedderArguments.get(0) + "' given to embed " + - "is not a sparse dimension of the target type " + targetType); - } + targetType = getOutputType(context); + + if (!validTarget(targetType)) + throw new VerificationException(this, "The generate target field must be a String"); + context.setCurrentType(createdOutputType()); } @Override protected void doExecute(ExecutionContext context) { if (context.getCurrentValue() == null) return; - Tensor output; + + String output; if (context.getCurrentValue().getDataType() == DataType.STRING) { - output = embedSingleValue(context); - } - else if (context.getCurrentValue().getDataType() instanceof ArrayDataType arrayType - && arrayType.getNestedType() == DataType.STRING) { - output = embedArrayValue(context); + output = generateSingleValue(context); } else { - throw new IllegalArgumentException("Embedding can only be done on string or string array fields, not " + + throw new IllegalArgumentException("Generate can only be done on string fields, not " + context.getCurrentValue().getDataType()); } - context.setCurrentValue(new TensorFieldValue(output)); + + context.setCurrentValue(new StringFieldValue(output)); } - private Tensor embedSingleValue(ExecutionContext context) { + private String generateSingleValue(ExecutionContext context) { StringFieldValue input = (StringFieldValue)context.getCurrentValue(); - return embed(input.getString(), targetType, context); - } - - @SuppressWarnings("unchecked") - private Tensor embedArrayValue(ExecutionContext context) { - var input = (Array)context.getCurrentValue(); - var builder = Tensor.Builder.of(targetType); - if (targetType.rank() == 2) - if (targetType.indexedSubtype().rank() == 1) - embedArrayValueToRank2Tensor(input, builder, context); - else if(targetType.mappedSubtype().rank() == 2) - embedArrayValueToRank2MappedTensor(input, builder, 
context); - else - throw new IllegalArgumentException("Embedding an array into " + targetType + " is not supported"); - else - embedArrayValueToRank3Tensor(input, builder, context); - return builder.build(); - } - - private void embedArrayValueToRank2Tensor(Array input, - Tensor.Builder builder, - ExecutionContext context) { - String mappedDimension = targetType.mappedSubtype().dimensions().get(0).name(); - String indexedDimension = targetType.indexedSubtype().dimensions().get(0).name(); - for (int i = 0; i < input.size(); i++) { - Tensor tensor = embed(input.get(i).getString(), targetType.indexedSubtype(), context); - for (Iterator cells = tensor.cellIterator(); cells.hasNext(); ) { - Tensor.Cell cell = cells.next(); - builder.cell() - .label(mappedDimension, i) - .label(indexedDimension, cell.getKey().numericLabel(0)) - .value(cell.getValue()); - } - } + return generate(input.getString(), targetType, context); } - private void embedArrayValueToRank3Tensor(Array input, - Tensor.Builder builder, - ExecutionContext context) { - String outerMappedDimension = embedderArguments.get(0); - String innerMappedDimension = targetType.mappedSubtype().dimensionNames().stream().filter(d -> !d.equals(outerMappedDimension)).findFirst().get(); - String indexedDimension = targetType.indexedSubtype().dimensions().get(0).name(); - long indexedDimensionSize = targetType.indexedSubtype().dimensions().get(0).size().get(); - var innerType = new TensorType.Builder(targetType.valueType()).mapped(innerMappedDimension).indexed(indexedDimension,indexedDimensionSize).build(); - int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension); - int indexedDimensionIndex = innerType.indexOfDimensionAsInt(indexedDimension); - for (int i = 0; i < input.size(); i++) { - Tensor tensor = embed(input.get(i).getString(), innerType, context); - for (Iterator cells = tensor.cellIterator(); cells.hasNext(); ) { - Tensor.Cell cell = cells.next(); - builder.cell() - 
.label(outerMappedDimension, i) - .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex)) - .label(indexedDimension, cell.getKey().numericLabel(indexedDimensionIndex)) - .value(cell.getValue()); - } - } - } - - private void embedArrayValueToRank2MappedTensor(Array input, - Tensor.Builder builder, - ExecutionContext context) { - String outerMappedDimension = embedderArguments.get(0); - String innerMappedDimension = targetType.mappedSubtype().dimensionNames().stream().filter(d -> !d.equals(outerMappedDimension)).findFirst().get(); - - var innerType = new TensorType.Builder(targetType.valueType()).mapped(innerMappedDimension).build(); - int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension); - - for (int i = 0; i < input.size(); i++) { - Tensor tensor = embed(input.get(i).getString(), innerType, context); - for (Iterator cells = tensor.cellIterator(); cells.hasNext(); ) { - Tensor.Cell cell = cells.next(); - builder.cell() - .label(outerMappedDimension, i) - .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex)) - .value(cell.getValue()); - } - } - } - - private Tensor embed(String input, TensorType targetType, ExecutionContext context) { - return embedder.embed(input, - new Embedder.Context(destination, context.getCache()).setLanguage(context.resolveLanguage(linguistics)) - .setEmbedderId(embedderId), - targetType); + private String generate(String input, DataType targetType, ExecutionContext context) { + return generator.generate( + input, + new Generator.Context(destination, context.getCache()) + .setLanguage(context.resolveLanguage(linguistics)) + .setGeneratorId(generatorId), + targetType + ); } @Override public DataType createdOutputType() { - return new TensorDataType(targetType); - } - - private static TensorType toTargetTensor(DataType dataType) { - if (dataType instanceof ArrayDataType) return toTargetTensor(dataType.getNestedType()); - if ( ! 
( dataType instanceof TensorDataType)) - throw new IllegalArgumentException("Expected a tensor data type but got " + dataType); - return ((TensorDataType)dataType).getTensorType(); + return targetType; } - private boolean validTarget(TensorType target) { - if (target.rank() == 1) // indexed or mapped 1d tensor - return true; - if (target.rank() == 2 && target.indexedSubtype().rank() == 1) - return true; // mixed 2d tensor - if(target.rank() == 2 && target.mappedSubtype().rank() == 2) - return true; // mapped 2d tensor - if (target.rank() == 3 && target.indexedSubtype().rank() == 1) - return true; // mixed 3d tensor - return false; + private boolean validTarget(DataType target) { + return target == DataType.STRING; } @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append("embed"); - if (this.embedderId != null && !this.embedderId.isEmpty()) - sb.append(" ").append(this.embedderId); - embedderArguments.forEach(arg -> sb.append(" ").append(arg)); + sb.append("generate"); + if (this.generatorId != null && !this.generatorId.isEmpty()) + sb.append(" ").append(this.generatorId); + generatorArguments.forEach(arg -> sb.append(" ").append(arg)); return sb.toString(); } @Override - public int hashCode() { return EmbedExpression.class.hashCode(); } + public int hashCode() { return GenerateExpression.class.hashCode(); } @Override public boolean equals(Object o) { - return o instanceof EmbedExpression; + return o instanceof GenerateExpression; } - private static String validEmbedders(Map embedders) { - List embedderIds = new ArrayList<>(); - embedders.forEach((key, value) -> embedderIds.add(key)); - embedderIds.sort(null); - return String.join(", ", embedderIds); + private static String validGenerators(Map generators) { + List generatorIds = new ArrayList<>(); + generators.forEach((key, value) -> generatorIds.add(key)); + generatorIds.sort(null); + return String.join(", ", generatorIds); } } diff --git 
a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index 2d0a8c1da468..6e761d5727dd 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -852,6 +852,7 @@ String identifier() : | | | + | | | | diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java index c4a53c1af683..e4171e60a709 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java @@ -63,6 +63,34 @@ public void testStatement(String expressionString, String input, String targetTe } } + public void testStatement2(String expressionString, String input, String targetTensorType, String expected) { + var expression = expressionFrom(expressionString); + TensorType tensorType = TensorType.fromSpec(targetTensorType); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myText", DataType.STRING)); + var tensorField = new Field("myTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + if (input != null) + adapter.setValue("myText", new StringFieldValue(input)); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); + + ExecutionContext context = new ExecutionContext(adapter); + expression.execute(context); + if (input == null) { + assertFalse(adapter.values.containsKey("myTensor")); + } + else { + assertTrue(adapter.values.containsKey("myTensor")); + assertEquals(Tensor.from(tensorType, expected), + ((TensorFieldValue) 
adapter.values.get("myTensor")).getTensor().get()); + } + } + public void testStatementThrows(String expressionString, String input, String expectedMessage) { try { testStatement(expressionString, input, null); diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java index ef2d39bdaa70..1c733d79fdd9 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java @@ -1,5 +1,6 @@ package com.yahoo.vespa.indexinglanguage; +import com.yahoo.language.process.Generator; import org.junit.Test; import java.util.Map; @@ -8,12 +9,33 @@ public class GeneratorScriptTestCase { @Test public void testGenerate() { - // No embedders - parsing only - var tester = new GeneratorScriptTester( - Map.of("gen1", new GeneratorScriptTester.RepeatMockGenerator())); - tester.testStatement("input myText | generate | index", "hello", "hello hello"); - tester.testStatement("input myText | generate gen1 | index", "hello", "hello hello"); - tester.testStatement("input myText | generate 'gen1' | 'index'", "hello", "hello hello"); - tester.testStatement("input myText | generate 'gen1' | 'index'", null, null); - } + // No generators - parsing only + var tester = new GeneratorScriptTester(Generator.throwsOnUse.asMap()); + tester.expressionFrom("input myText | generate | attribute 'myGeneratedText'"); + + // One generator + tester = new GeneratorScriptTester(Map.of( + "gen1", new GeneratorScriptTester.RepeatMockGenerator("myDocument.myGeneratedText"))); + tester.testStatement("input myText | generate | attribute myGeneratedText", + "hello", "hello hello"); + tester.testStatement("input myText | generate gen1 | attribute myGeneratedText", + "hello", "hello hello"); + tester.testStatement("input myText | generate 'gen1' | 
attribute 'myGeneratedText'", + "hello", "hello hello"); + tester.testStatement("input myText | generate 'gen1' | attribute myGeneratedText", + null, null); + + // Two generators + tester = new GeneratorScriptTester(Map.of( + "gen1", new GeneratorScriptTester.RepeatMockGenerator("myDocument.myGeneratedText", 2), + "gen2", new GeneratorScriptTester.RepeatMockGenerator("myDocument.myGeneratedText", 3))); + tester.testStatement("input myText | generate gen1 | attribute myGeneratedText", + "hello", "hello hello"); + tester.testStatement("input myText | generate gen2 | attribute myGeneratedText", + "hello", "hello hello hello"); + tester.testStatementThrows("input myText | generate | attribute myGeneratedText", + "hello", "Multiple generators are provided but no generator id is given. Valid generators are gen1, gen2"); + tester.testStatementThrows("input myText | generate gen3 | attribute myGeneratedText", + "hello", "Can't find generator 'gen3'. Valid generators are gen1, gen2"); + } } \ No newline at end of file diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java index 388f6ebe2853..adfc94d8788e 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java @@ -15,6 +15,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class GeneratorScriptTester { private final Map generators; @@ -22,36 +23,42 @@ public class GeneratorScriptTester { public GeneratorScriptTester(Map generators) { this.generators = generators; } - - public static class RepeatMockGenerator implements Generator { - public String generate(String prompt) { - return prompt + " " + prompt; - } - } 
public void testStatement(String expressionString, String input, String expected) { var expression = expressionFrom(expressionString); SimpleTestAdapter adapter = new SimpleTestAdapter(); adapter.createField(new Field("myText", DataType.STRING)); + var generatedField = new Field("myGeneratedText", DataType.STRING); + adapter.createField(generatedField); if (input != null) adapter.setValue("myText", new StringFieldValue(input)); - - expression.setStatementOutput(new DocumentType("myDocument"), new Field("myText", DataType.STRING)); + + expression.setStatementOutput(new DocumentType("myDocument"), generatedField); + ExecutionContext context = new ExecutionContext(adapter); expression.execute(context); if (input == null) { - assertFalse(adapter.values.containsKey("myText")); + assertFalse(adapter.values.containsKey("myGeneratedText")); } else { - assertTrue(adapter.values.containsKey("myText")); - assertEquals(expected, ((StringFieldValue)adapter.values.get("myText")).getString()); + assertTrue(adapter.values.containsKey("myGeneratedText")); + assertEquals(expected, ((StringFieldValue)adapter.values.get("myGeneratedText")).getString()); + } + } + + public void testStatementThrows(String expressionString, String input, String expectedMessage) { + try { + testStatement(expressionString, input, null); + fail(); + } catch (IllegalStateException e) { + assertEquals(expectedMessage, e.getMessage()); } } - private Expression expressionFrom(String string) { + public Expression expressionFrom(String string) { try { return Expression.fromString(string, new SimpleLinguistics(), Map.of(), generators); } @@ -60,4 +67,32 @@ private Expression expressionFrom(String string) { } } + public static class RepeatMockGenerator implements Generator { + final String expectedDestination; + final int repetitions; + + public RepeatMockGenerator(String expectedDestination) { + this(expectedDestination, 2); + } + + public RepeatMockGenerator(String expectedDestination, int repetitions) { + 
this.expectedDestination = expectedDestination; + this.repetitions = repetitions; + } + + public String generate(String prompt, Context context, DataType dataType) { + var stringBuilder = new StringBuilder(); + + for (int i = 0; i < repetitions; i++) { + stringBuilder.append(prompt); + stringBuilder.append(" "); + } + + return stringBuilder.toString().trim(); + } + + void verifyDestination(Generator.Context context) { + assertEquals(expectedDestination, context.getDestination()); + } + } } diff --git a/linguistics/pom.xml b/linguistics/pom.xml index 40767bfff268..a23b2ce8e9bc 100644 --- a/linguistics/pom.xml +++ b/linguistics/pom.xml @@ -89,6 +89,12 @@ mockito-core test + + com.yahoo.vespa + document + 8.424.11 + compile + diff --git a/linguistics/src/main/java/com/yahoo/language/process/Generator.java b/linguistics/src/main/java/com/yahoo/language/process/Generator.java index 35ad791c3a7d..e1fe84a6b437 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Generator.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Generator.java @@ -1,6 +1,12 @@ package com.yahoo.language.process; +import com.yahoo.collections.LazyMap; +import com.yahoo.document.DataType; +import com.yahoo.language.Language; + import java.util.Map; +import java.util.Objects; +import java.util.function.Supplier; public interface Generator { @@ -20,8 +26,90 @@ default Map asMap(String name) { return Map.of(name, this); } - String generate(String prompt); + String generate(String prompt, Context context, DataType dataType); + + + class Context { + private Language language = Language.UNKNOWN; + private String destination; + private String generatorId = "unknown"; + private final Map cache; + + public Context(String destination) { + this(destination, LazyMap.newHashMap()); + } + + /** + * @param destination the name of the recipient of the generated output + * @param cache a cache shared between all generate invocations for a single request + */ + public Context(String 
destination, Map cache) { + this.destination = destination; + this.cache = Objects.requireNonNull(cache); + } + + private Context(Context other) { + language = other.language; + destination = other.destination; + generatorId = other.generatorId; + this.cache = other.cache; + } + + public Generator.Context copy() { return new Context(this); } + + /** Returns the language of the text, or UNKNOWN (default) to use language independent generation */ + public Language getLanguage() { return language; } + + /** Sets the language of the text, or UNKNOWN to use language independent generation */ + public Context setLanguage(Language language) { + this.language = language != null ? language : Language.UNKNOWN; + return this; + } + + /** + * Returns the name of the recipient of the generated text. + * This is either a query feature name + * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field"). + * This cannot be null. + */ + public String getDestination() { return destination; } + + /** + * Sets the name of the recipient of the generated text. + * This is either a query feature name + * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field"). + */ + public Context setDestination(String destination) { + this.destination = destination; + return this; + } + + /** Returns the generator id or 'unknown' if not set */ + public String getGeneratorId() { return generatorId; } + /** Sets the generator id */ + public Context setGeneratorId(String generatorId) { + this.generatorId = generatorId; + return this; + } + + public void putCachedValue(Object key, Object value) { + cache.put(key, value); + } + + /** Returns a cached value, or null if not present. */ + public Object getCachedValue(Object key) { + return cache.get(key); + } + + /** Returns the cached value, or computes and caches it if not present. 
*/ + @SuppressWarnings("unchecked") + public T computeCachedValueIfAbsent(Object key, Supplier supplier) { + return (T) cache.computeIfAbsent(key, __ -> supplier.get()); + } + + } + class FailingGenerator implements Generator { private final String message; @@ -33,7 +121,7 @@ public FailingGenerator(String message) { this.message = message; } - public String generate(String prompt) { + public String generate(String prompt, Context context, DataType dataType) { throw new IllegalStateException(message); } } From 98073fe13e40e775b1dd36a525e7ec0caf108220 Mon Sep 17 00:00:00 2001 From: Gleb Sizov Date: Tue, 5 Nov 2024 08:25:53 +0100 Subject: [PATCH 03/12] draft generate expression --- .../com/yahoo/schema/document/SDField.java | 8 +- .../fieldoperation/IndexingOperation.java | 12 +-- config-model/src/main/javacc/SchemaParser.jj | 7 +- .../expressions/ScriptExpression.java | 4 +- .../expressions/StatementExpression.java | 4 +- .../src/main/javacc/IndexingParser.jj | 5 ++ .../ccc/indexinglanguage/IndexingParser.ccc | 7 ++ .../java/ai/vespa/generative/Generator.java | 63 ++++++++++++++++ .../resources/configdefinitions/generator.def | 11 +++ .../ai/vespa/generative/GeneratorTest.java | 73 +++++++++++++++++++ 10 files changed, 182 insertions(+), 12 deletions(-) create mode 100644 model-integration/src/main/java/ai/vespa/generative/Generator.java create mode 100644 model-integration/src/main/resources/configdefinitions/generator.def create mode 100644 model-integration/src/test/java/ai/vespa/generative/GeneratorTest.java diff --git a/config-model/src/main/java/com/yahoo/schema/document/SDField.java b/config-model/src/main/java/com/yahoo/schema/document/SDField.java index 2483fa476676..12a00652eaec 100644 --- a/config-model/src/main/java/com/yahoo/schema/document/SDField.java +++ b/config-model/src/main/java/com/yahoo/schema/document/SDField.java @@ -13,6 +13,7 @@ import com.yahoo.documentmodel.TemporaryUnknownType; import com.yahoo.language.Linguistics; import 
com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.schema.Index; import com.yahoo.schema.Schema; @@ -399,12 +400,13 @@ public boolean hasSingleAttribute() { /** Parse an indexing expression which will use the simple linguistics implementation suitable for testing */ public void parseIndexingScript(String schemaName, String script) { - parseIndexingScript(schemaName, script, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + parseIndexingScript(schemaName, script, new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); } - public void parseIndexingScript(String schemaName, String script, Linguistics linguistics, Map embedders) { + public void parseIndexingScript(String schemaName, String script, Linguistics linguistics, + Map embedders, Map generators) { try { - ScriptParserContext config = new ScriptParserContext(linguistics, embedders); + ScriptParserContext config = new ScriptParserContext(linguistics, embedders, generators); config.setInputStream(new IndexingInput(script)); setIndexingScript(schemaName, ScriptExpression.newInstance(config)); } catch (ParseException e) { diff --git a/config-model/src/main/java/com/yahoo/schema/fieldoperation/IndexingOperation.java b/config-model/src/main/java/com/yahoo/schema/fieldoperation/IndexingOperation.java index 11065f040ea5..b5d8f4ed369b 100644 --- a/config-model/src/main/java/com/yahoo/schema/fieldoperation/IndexingOperation.java +++ b/config-model/src/main/java/com/yahoo/schema/fieldoperation/IndexingOperation.java @@ -3,6 +3,7 @@ import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.schema.document.SDField; import com.yahoo.schema.parser.ParseException; @@ -34,13 +35,14 @@ public void apply(String schemaName, SDField field) { /** Creates an 
indexing operation which will use the simple linguistics implementation suitable for testing */ public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine) throws ParseException { - return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), + Generator.throwsOnUse.asMap()); } - public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine, - Linguistics linguistics, Map embedders) - throws ParseException { - ScriptParserContext config = new ScriptParserContext(linguistics, embedders); + public static IndexingOperation fromStream( + SimpleCharStream input, boolean multiLine, Linguistics linguistics, Map embedders, + Map generators) throws ParseException { + ScriptParserContext config = new ScriptParserContext(linguistics, embedders, generators); config.setAnnotatorConfig(new AnnotatorConfig()); config.setInputStream(input); ScriptExpression exp; diff --git a/config-model/src/main/javacc/SchemaParser.jj b/config-model/src/main/javacc/SchemaParser.jj index c9eff88764f7..7403c5cfc0fe 100644 --- a/config-model/src/main/javacc/SchemaParser.jj +++ b/config-model/src/main/javacc/SchemaParser.jj @@ -17,6 +17,7 @@ import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.api.ModelContext; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.search.query.ranking.Diversity; import com.yahoo.schema.DistributableResource; @@ -91,13 +92,15 @@ public class SchemaParser { * @param multiline Whether or not to allow multi-line expressions. * @param linguistics What to use for tokenizing. 
*/ - private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Map embedders) throws ParseException { + private IndexingOperation newIndexingOperation( + boolean multiline, Linguistics linguistics, Map embedders, + Map generators) throws ParseException { SimpleCharStream input = (SimpleCharStream)token_source.input_stream; if (token.next != null) { input.backup(token.next.image.length()); } try { - return IndexingOperation.fromStream(input, multiline, linguistics, embedders); + return IndexingOperation.fromStream(input, multiline, linguistics, embedders, generators); } finally { token.next = null; jj_ntk = -1; diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java index 47031dc71bbd..4bc5ba283faa 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java @@ -120,7 +120,9 @@ public static ScriptExpression fromString(String expression, Linguistics linguis return newInstance(new ScriptParserContext(linguistics, embedders, Map.of()).setInputStream(new IndexingInput(expression))); } - public static Expression fromString(String expression, Linguistics linguistics, Map embedders, Map generators) throws ParseException { + public static Expression fromString( + String expression, Linguistics linguistics, Map embedders, + Map generators) throws ParseException { return newInstance(new ScriptParserContext(linguistics, embedders, generators).setInputStream(new IndexingInput(expression))); } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java index d9eade7628a3..aaba69082141 
100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java @@ -141,7 +141,9 @@ public static StatementExpression fromString(String expression, Linguistics ling return newInstance(new ScriptParserContext(linguistics, embedders, Map.of()).setInputStream(new IndexingInput(expression))); } - public static StatementExpression fromString(String expression, Linguistics linguistics, Map embedders, Map generators) throws ParseException { + public static StatementExpression fromString( + String expression, Linguistics linguistics, Map embedders, + Map generators) throws ParseException { return newInstance(new ScriptParserContext(linguistics, embedders, generators).setInputStream(new IndexingInput(expression))); } diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index 6e761d5727dd..dd068981d2c2 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -66,6 +66,11 @@ public class IndexingParser { return this; } + public IndexingParser setGenerators(Map generators) { + this.generators = generators; + return this; + } + public IndexingParser setAnnotatorConfig(AnnotatorConfig cfg) { annotatorCfg = cfg; return this; diff --git a/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc b/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc index 9237ff5cd143..f818bf20a6b8 100644 --- a/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc +++ b/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc @@ -33,6 +33,7 @@ INJECT IndexingParser: import com.yahoo.vespa.indexinglanguage.expressions.*; import 
com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.language.process.Embedder; + import com.yahoo.language.process.Generator; import com.yahoo.language.Linguistics; { /** @@ -43,6 +44,7 @@ INJECT IndexingParser: private String defaultFieldName; private Linguistics linguistics; private Map embedders; + private Map generators; private AnnotatorConfig annotatorCfg; private PrintStream logger = new PrintStream( @@ -73,6 +75,11 @@ INJECT IndexingParser: this.embedders = embedders; return this; } + + public IndexingParser setGenerators(Map generators) { + this.generators = generators; + return this; + } public IndexingParser setAnnotatorConfig(AnnotatorConfig cfg) { annotatorCfg = cfg; diff --git a/model-integration/src/main/java/ai/vespa/generative/Generator.java b/model-integration/src/main/java/ai/vespa/generative/Generator.java new file mode 100644 index 000000000000..8c882516c103 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/generative/Generator.java @@ -0,0 +1,63 @@ +package ai.vespa.generative; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.ComponentId; +import com.yahoo.component.annotation.Inject; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.document.DataType; + +import java.util.logging.Logger; +import java.util.stream.Collectors; + +public class Generator extends AbstractComponent implements com.yahoo.language.process.Generator { + private static final Logger log = Logger.getLogger(Generator.class.getName()); + private final LanguageModel languageModel; + + @Inject + public Generator(GeneratorConfig config, ComponentRegistry languageModels) { + this.languageModel = findLanguageModel(config.providerId(), languageModels); + } + + private LanguageModel findLanguageModel(String providerId, ComponentRegistry languageModels) + throws 
IllegalArgumentException + { + if (languageModels.allComponents().isEmpty()) { + throw new IllegalArgumentException("No language models were found"); + } + + if (providerId == null || providerId.isEmpty()) { + var entry = languageModels.allComponentsById().entrySet().stream().findFirst(); + + if (entry.isEmpty()) { + throw new IllegalArgumentException("No language models were found"); // shouldn't happen given check above + } + + log.info("Language model provider was not found in config. " + + "Fallback to using first available language model: " + entry.get().getKey()); + + return entry.get().getValue(); + } + + final LanguageModel languageModel = languageModels.getComponent(providerId); + + if (languageModel == null) { + throw new IllegalArgumentException("No component with id '" + providerId + "' was found. " + + "Available LLM components are: " + languageModels.allComponentsById().keySet().stream() + .map(ComponentId::toString).collect(Collectors.joining(","))); + } + + return languageModel; + } + + @Override + public String generate(String prompt, Context context, DataType dataType) { + var options = new InferenceParameters(s -> ""); + var promptObj = StringPrompt.from(prompt); + var completions = languageModel.complete(promptObj, options); + var firstCompletion = completions.get(0); + return firstCompletion.text(); + } +} diff --git a/model-integration/src/main/resources/configdefinitions/generator.def b/model-integration/src/main/resources/configdefinitions/generator.def new file mode 100644 index 000000000000..133015d76d2d --- /dev/null +++ b/model-integration/src/main/resources/configdefinitions/generator.def @@ -0,0 +1,11 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package=ai.vespa.generative + +# The external LLM provider - the id of a LanguageModel component +providerId string default="" + +# The default prompt to use if not overridden in query +prompt string default="" + +# The default prompt template file to use if not overridden in query. Above prompt has precedence if it is set. +promptTemplate path optional \ No newline at end of file diff --git a/model-integration/src/test/java/ai/vespa/generative/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/generative/GeneratorTest.java new file mode 100644 index 000000000000..2ef5139e7849 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/generative/GeneratorTest.java @@ -0,0 +1,73 @@ +package ai.vespa.generative; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.component.ComponentId; +import com.yahoo.component.provider.ComponentRegistry; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +import com.yahoo.document.DataType; +import org.junit.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +public class GeneratorTest { + @Test + public void testGeneration() { + LanguageModel languageModel1 = new RepeatMockLanguageModel(1); + LanguageModel languageModel2 = new RepeatMockLanguageModel(2); + var languageModels = Map.of("mock1", languageModel1, "mock2", languageModel2); + + var config1 = new GeneratorConfig.Builder().providerId("mock1").build(); + var generator1 = createGenerator(config1, languageModels); + var context = new com.yahoo.language.process.Generator.Context("schema.indexing"); + var result1 = generator1.generate("hello", context, DataType.STRING); + assertEquals("hello", result1); + + var config2 = new GeneratorConfig.Builder().providerId("mock2").build(); + var generator2 = createGenerator(config2, 
Map.of("mock1", languageModel1, "mock2", languageModel2)); + var result2 = generator2.generate("hello", context, DataType.STRING); + assertEquals("hello hello", result2); + } + + private static Generator createGenerator(GeneratorConfig config, Map languageModels) { + ComponentRegistry models = new ComponentRegistry<>(); + languageModels.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new Generator(config, models); + } + + public static class RepeatMockLanguageModel implements LanguageModel { + private final int repetitions; + + public RepeatMockLanguageModel(int repetitions) { + this.repetitions = repetitions; + } + + @Override + public List complete(Prompt prompt, InferenceParameters params) { + var stringBuilder = new StringBuilder(); + + for (int i = 0; i < repetitions; i++) { + stringBuilder.append(prompt.asString()); + stringBuilder.append(" "); + } + + return List.of(Completion.from(stringBuilder.toString().trim())); + } + + @Override + public CompletableFuture completeAsync(Prompt prompt, + InferenceParameters params, + Consumer consumer) { + throw new UnsupportedOperationException(); + } + } +} From 7a293db3eb8340c70b0923043d9d63ff62e9da4e Mon Sep 17 00:00:00 2001 From: Gleb Sizov Date: Tue, 5 Nov 2024 15:12:54 +0100 Subject: [PATCH 04/12] Refactoring. LocalLLM and OpenAI implement Generator. 
--- .../{Generator.java => GeneratorUtils.java} | 23 ++++++--------- .../generative/LanguageModelGenerator.java | 29 +++++++++++++++++++ .../java/ai/vespa/llm/clients/LocalLLM.java | 9 +++++- .../java/ai/vespa/llm/clients/OpenAI.java | 10 ++++++- ...rator.def => language-model-generator.def} | 0 ...t.java => LanguageModelGeneratorTest.java} | 10 +++---- 6 files changed, 60 insertions(+), 21 deletions(-) rename model-integration/src/main/java/ai/vespa/generative/{Generator.java => GeneratorUtils.java} (71%) create mode 100644 model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java rename model-integration/src/main/resources/configdefinitions/{generator.def => language-model-generator.def} (100%) rename model-integration/src/test/java/ai/vespa/generative/{GeneratorTest.java => LanguageModelGeneratorTest.java} (85%) diff --git a/model-integration/src/main/java/ai/vespa/generative/Generator.java b/model-integration/src/main/java/ai/vespa/generative/GeneratorUtils.java similarity index 71% rename from model-integration/src/main/java/ai/vespa/generative/Generator.java rename to model-integration/src/main/java/ai/vespa/generative/GeneratorUtils.java index 8c882516c103..33e22b2aaf55 100644 --- a/model-integration/src/main/java/ai/vespa/generative/Generator.java +++ b/model-integration/src/main/java/ai/vespa/generative/GeneratorUtils.java @@ -3,25 +3,19 @@ import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.LanguageModel; import ai.vespa.llm.completion.StringPrompt; -import com.yahoo.component.AbstractComponent; import com.yahoo.component.ComponentId; -import com.yahoo.component.annotation.Inject; import com.yahoo.component.provider.ComponentRegistry; -import com.yahoo.document.DataType; import java.util.logging.Logger; import java.util.stream.Collectors; -public class Generator extends AbstractComponent implements com.yahoo.language.process.Generator { - private static final Logger log = Logger.getLogger(Generator.class.getName()); - 
private final LanguageModel languageModel; - @Inject - public Generator(GeneratorConfig config, ComponentRegistry languageModels) { - this.languageModel = findLanguageModel(config.providerId(), languageModels); - } - - private LanguageModel findLanguageModel(String providerId, ComponentRegistry languageModels) +/** + * Provide utilities to implement Generator interface. + * It is used by language models as well as other generators. + */ +public class GeneratorUtils { + public static LanguageModel findLanguageModel(String providerId, ComponentRegistry languageModels, Logger log) throws IllegalArgumentException { if (languageModels.allComponents().isEmpty()) { @@ -52,8 +46,9 @@ private LanguageModel findLanguageModel(String providerId, ComponentRegistry ""); var promptObj = StringPrompt.from(prompt); var completions = languageModel.complete(promptObj, options); diff --git a/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java b/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java new file mode 100644 index 000000000000..ce33c49a3972 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java @@ -0,0 +1,29 @@ +package ai.vespa.generative; + +import ai.vespa.llm.LanguageModel; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.document.DataType; + +import java.util.logging.Logger; + +/** + * A generator that uses a language model to generate text. + * Unlike using a language model directly, this is supposed to be extended with configurable parameters, + * e.g. prompt template, postprocessors, etc. 
+ */ +public class LanguageModelGenerator extends AbstractComponent implements com.yahoo.language.process.Generator { + private static final Logger logger = Logger.getLogger(LanguageModelGenerator.class.getName()); + private final LanguageModel languageModel; + + @Inject + public LanguageModelGenerator(LanguageModelGeneratorConfig config, ComponentRegistry languageModels) { + this.languageModel = GeneratorUtils.findLanguageModel(config.providerId(), languageModels, logger); + } + + @Override + public String generate(String prompt, Context context, DataType dataType) { + return GeneratorUtils.generate(prompt, languageModel); + } +} diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java index bbb82db71391..5fe3e9e38930 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package ai.vespa.llm.clients; +import ai.vespa.generative.GeneratorUtils; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.LanguageModel; import ai.vespa.llm.LanguageModelException; @@ -8,6 +9,8 @@ import ai.vespa.llm.completion.Prompt; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; +import com.yahoo.document.DataType; +import com.yahoo.language.process.Generator; import de.kherud.llama.LlamaModel; import de.kherud.llama.ModelParameters; @@ -31,7 +34,7 @@ * * @author lesters */ -public class LocalLLM extends AbstractComponent implements LanguageModel { +public class LocalLLM extends AbstractComponent implements LanguageModel, Generator { private final static Logger logger = Logger.getLogger(LocalLLM.class.getName()); @@ -152,4 +155,8 @@ private String rejectedExecutionReason(String prepend) { } + @Override + public String generate(String prompt, Context context, DataType dataType) { + return GeneratorUtils.generate(prompt, this); + } } diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java index 82e19d47c927..8a977b2d331c 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package ai.vespa.llm.clients; +import ai.vespa.generative.GeneratorUtils; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.client.openai.OpenAiClient; import ai.vespa.llm.completion.Completion; @@ -8,6 +9,8 @@ import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; import com.yahoo.container.jdisc.secretstore.SecretStore; +import com.yahoo.document.DataType; +import com.yahoo.language.process.Generator; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -19,7 +22,7 @@ * @author lesters */ @Beta -public class OpenAI extends ConfigurableLanguageModel { +public class OpenAI extends ConfigurableLanguageModel implements Generator { private final OpenAiClient client; @@ -44,5 +47,10 @@ public CompletableFuture completeAsync(Prompt prompt, setEndpoint(parameters); return client.completeAsync(prompt, parameters, consumer); } + + @Override + public String generate(String prompt, Context context, DataType dataType) { + return GeneratorUtils.generate(prompt, this); + } } diff --git a/model-integration/src/main/resources/configdefinitions/generator.def b/model-integration/src/main/resources/configdefinitions/language-model-generator.def similarity index 100% rename from model-integration/src/main/resources/configdefinitions/generator.def rename to model-integration/src/main/resources/configdefinitions/language-model-generator.def diff --git a/model-integration/src/test/java/ai/vespa/generative/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java similarity index 85% rename from model-integration/src/test/java/ai/vespa/generative/GeneratorTest.java rename to model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java index 2ef5139e7849..9d95fd8d84fa 100644 --- a/model-integration/src/test/java/ai/vespa/generative/GeneratorTest.java +++ b/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java @@ -18,30 +18,30 @@ 
import static org.junit.jupiter.api.Assertions.assertEquals; -public class GeneratorTest { +public class LanguageModelGeneratorTest { @Test public void testGeneration() { LanguageModel languageModel1 = new RepeatMockLanguageModel(1); LanguageModel languageModel2 = new RepeatMockLanguageModel(2); var languageModels = Map.of("mock1", languageModel1, "mock2", languageModel2); - var config1 = new GeneratorConfig.Builder().providerId("mock1").build(); + var config1 = new LanguageModelGeneratorConfig.Builder().providerId("mock1").build(); var generator1 = createGenerator(config1, languageModels); var context = new com.yahoo.language.process.Generator.Context("schema.indexing"); var result1 = generator1.generate("hello", context, DataType.STRING); assertEquals("hello", result1); - var config2 = new GeneratorConfig.Builder().providerId("mock2").build(); + var config2 = new LanguageModelGeneratorConfig.Builder().providerId("mock2").build(); var generator2 = createGenerator(config2, Map.of("mock1", languageModel1, "mock2", languageModel2)); var result2 = generator2.generate("hello", context, DataType.STRING); assertEquals("hello hello", result2); } - private static Generator createGenerator(GeneratorConfig config, Map languageModels) { + private static LanguageModelGenerator createGenerator(LanguageModelGeneratorConfig config, Map languageModels) { ComponentRegistry models = new ComponentRegistry<>(); languageModels.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); models.freeze(); - return new Generator(config, models); + return new LanguageModelGenerator(config, models); } public static class RepeatMockLanguageModel implements LanguageModel { From 82ff3a9277e2d2a16cead103ca56d8cf8e5a53f3 Mon Sep 17 00:00:00 2001 From: lesters Date: Thu, 14 Nov 2024 11:52:14 +0100 Subject: [PATCH 05/12] Fix build --- config-model/src/main/javacc/SchemaParser.jj | 2 +- .../docprocs/indexing/ScriptManager.java | 2 +- linguistics/abi-spec.json | 55 
+++++++++++++++++++ linguistics/pom.xml | 13 +++-- model-integration/abi-spec.json | 13 +++-- 5 files changed, 73 insertions(+), 12 deletions(-) diff --git a/config-model/src/main/javacc/SchemaParser.jj b/config-model/src/main/javacc/SchemaParser.jj index 7403c5cfc0fe..0e2275481364 100644 --- a/config-model/src/main/javacc/SchemaParser.jj +++ b/config-model/src/main/javacc/SchemaParser.jj @@ -83,7 +83,7 @@ public class SchemaParser { */ @SuppressWarnings("deprecation") private IndexingOperation newIndexingOperation(boolean multiline) throws ParseException { - return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); } /** diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java index 3088083912bc..8f2ec9c288e5 100644 --- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java +++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java @@ -70,7 +70,7 @@ private static Map> createScriptsMap(Docume Linguistics linguistics, Map embedders) { Map> documentFieldScripts = new HashMap<>(config.ilscript().size()); - ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders); + ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders, null); parserContext.getAnnotatorConfig().setMaxTermOccurrences(config.maxtermoccurrences()); parserContext.getAnnotatorConfig().setMaxTokenizeLength(config.fieldmatchmaxlength()); diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index ba5bc3c79702..2f2fec6d257a 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -403,6 +403,61 @@ "public static final com.yahoo.language.process.Embedder throwsOnUse" ] }, + "com.yahoo.language.process.Generator$Context" : { + 
"superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void (java.lang.String)", + "public void (java.lang.String, java.util.Map)", + "public com.yahoo.language.process.Generator$Context copy()", + "public com.yahoo.language.Language getLanguage()", + "public com.yahoo.language.process.Generator$Context setLanguage(com.yahoo.language.Language)", + "public java.lang.String getDestination()", + "public com.yahoo.language.process.Generator$Context setDestination(java.lang.String)", + "public java.lang.String getGeneratorId()", + "public com.yahoo.language.process.Generator$Context setGeneratorId(java.lang.String)", + "public void putCachedValue(java.lang.Object, java.lang.Object)", + "public java.lang.Object getCachedValue(java.lang.Object)", + "public java.lang.Object computeCachedValueIfAbsent(java.lang.Object, java.util.function.Supplier)" + ], + "fields" : [ ] + }, + "com.yahoo.language.process.Generator$FailingGenerator" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.language.process.Generator" + ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void ()", + "public void (java.lang.String)", + "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context, com.yahoo.document.DataType)" + ], + "fields" : [ ] + }, + "com.yahoo.language.process.Generator" : { + "superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public", + "interface", + "abstract" + ], + "methods" : [ + "public java.util.Map asMap()", + "public java.util.Map asMap(java.lang.String)", + "public abstract java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context, com.yahoo.document.DataType)" + ], + "fields" : [ + "public static final java.lang.String defaultGeneratorId", + "public static final com.yahoo.language.process.Generator throwsOnUse" + ] + }, "com.yahoo.language.process.GramSplitter$Gram" : { 
"superClass" : "java.lang.Object", "interfaces" : [ ], diff --git a/linguistics/pom.xml b/linguistics/pom.xml index a23b2ce8e9bc..e6947de3049b 100644 --- a/linguistics/pom.xml +++ b/linguistics/pom.xml @@ -68,6 +68,13 @@ provided + + com.yahoo.vespa + document + ${project.version} + provided + + junit @@ -89,12 +96,6 @@ mockito-core test - - com.yahoo.vespa - document - 8.424.11 - compile - diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json index 31f2b64d728c..c8381d018f2f 100644 --- a/model-integration/abi-spec.json +++ b/model-integration/abi-spec.json @@ -157,7 +157,8 @@ "ai.vespa.llm.clients.LocalLLM" : { "superClass" : "com.yahoo.component.AbstractComponent", "interfaces" : [ - "ai.vespa.llm.LanguageModel" + "ai.vespa.llm.LanguageModel", + "com.yahoo.language.process.Generator" ], "attributes" : [ "public" @@ -166,20 +167,24 @@ "public void (ai.vespa.llm.clients.LlmLocalClientConfig)", "public void deconstruct()", "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)", + "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context, com.yahoo.document.DataType)" ], "fields" : [ ] }, "ai.vespa.llm.clients.OpenAI" : { "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel", - "interfaces" : [ ], + "interfaces" : [ + "com.yahoo.language.process.Generator" + ], "attributes" : [ "public" ], "methods" : [ "public void (ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public 
java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)", + "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context, com.yahoo.document.DataType)" ], "fields" : [ ] }, From baccc993416b32697c57efa40321909a9e550c5f Mon Sep 17 00:00:00 2001 From: lesters Date: Thu, 14 Nov 2024 12:05:33 +0100 Subject: [PATCH 06/12] Remove dependency on document in linguistics --- .../indexinglanguage/expressions/GenerateExpression.java | 3 +-- .../vespa/indexinglanguage/GeneratorScriptTester.java | 2 +- linguistics/abi-spec.json | 4 ++-- linguistics/pom.xml | 7 ------- .../main/java/com/yahoo/language/process/Generator.java | 5 ++--- model-integration/abi-spec.json | 4 ++-- .../java/ai/vespa/generative/LanguageModelGenerator.java | 2 +- .../src/main/java/ai/vespa/llm/clients/LocalLLM.java | 2 +- .../src/main/java/ai/vespa/llm/clients/OpenAI.java | 2 +- .../ai/vespa/generative/LanguageModelGeneratorTest.java | 4 ++-- 10 files changed, 13 insertions(+), 22 deletions(-) diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java index 46b7d0f3c9ca..34c7c22399a5 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java @@ -119,8 +119,7 @@ private String generate(String input, DataType targetType, ExecutionContext cont input, new Generator.Context(destination, context.getCache()) .setLanguage(context.resolveLanguage(linguistics)) - .setGeneratorId(generatorId), - targetType + 
.setGeneratorId(generatorId) ); } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java index adfc94d8788e..7492f0f578f4 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java @@ -80,7 +80,7 @@ public RepeatMockGenerator(String expectedDestination, int repetitions) { this.repetitions = repetitions; } - public String generate(String prompt, Context context, DataType dataType) { + public String generate(String prompt, Context context) { var stringBuilder = new StringBuilder(); for (int i = 0; i < repetitions; i++) { diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 2f2fec6d257a..72975c62bab7 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -436,7 +436,7 @@ "methods" : [ "public void ()", "public void (java.lang.String)", - "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context, com.yahoo.document.DataType)" + "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context)" ], "fields" : [ ] }, @@ -451,7 +451,7 @@ "methods" : [ "public java.util.Map asMap()", "public java.util.Map asMap(java.lang.String)", - "public abstract java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context, com.yahoo.document.DataType)" + "public abstract java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context)" ], "fields" : [ "public static final java.lang.String defaultGeneratorId", diff --git a/linguistics/pom.xml b/linguistics/pom.xml index e6947de3049b..40767bfff268 100644 --- a/linguistics/pom.xml +++ b/linguistics/pom.xml @@ -68,13 +68,6 @@ provided - - com.yahoo.vespa - document - ${project.version} - provided - - 
junit diff --git a/linguistics/src/main/java/com/yahoo/language/process/Generator.java b/linguistics/src/main/java/com/yahoo/language/process/Generator.java index e1fe84a6b437..50cb89f0ad93 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Generator.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Generator.java @@ -1,7 +1,6 @@ package com.yahoo.language.process; import com.yahoo.collections.LazyMap; -import com.yahoo.document.DataType; import com.yahoo.language.Language; import java.util.Map; @@ -26,7 +25,7 @@ default Map asMap(String name) { return Map.of(name, this); } - String generate(String prompt, Context context, DataType dataType); + String generate(String prompt, Context context); class Context { @@ -121,7 +120,7 @@ public FailingGenerator(String message) { this.message = message; } - public String generate(String prompt, Context context, DataType dataType) { + public String generate(String prompt, Context context) { throw new IllegalStateException(message); } } diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json index c8381d018f2f..2614acb60320 100644 --- a/model-integration/abi-spec.json +++ b/model-integration/abi-spec.json @@ -168,7 +168,7 @@ "public void deconstruct()", "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)", - "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context, com.yahoo.document.DataType)" + "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context)" ], "fields" : [ ] }, @@ -184,7 +184,7 @@ "public void (ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", 
"public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)", - "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context, com.yahoo.document.DataType)" + "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context)" ], "fields" : [ ] }, diff --git a/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java b/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java index ce33c49a3972..b2ba011b4f13 100644 --- a/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java +++ b/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java @@ -23,7 +23,7 @@ public LanguageModelGenerator(LanguageModelGeneratorConfig config, ComponentRegi } @Override - public String generate(String prompt, Context context, DataType dataType) { + public String generate(String prompt, Context context) { return GeneratorUtils.generate(prompt, languageModel); } } diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java index 5fe3e9e38930..c77ad25a00ee 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java @@ -156,7 +156,7 @@ private String rejectedExecutionReason(String prepend) { @Override - public String generate(String prompt, Context context, DataType dataType) { + public String generate(String prompt, Context context) { return GeneratorUtils.generate(prompt, this); } } diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java index 8a977b2d331c..a9047ec3f268 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ 
b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -49,7 +49,7 @@ public CompletableFuture completeAsync(Prompt prompt, } @Override - public String generate(String prompt, Context context, DataType dataType) { + public String generate(String prompt, Context context) { return GeneratorUtils.generate(prompt, this); } } diff --git a/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java b/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java index 9d95fd8d84fa..8a412394d5d7 100644 --- a/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java +++ b/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java @@ -28,12 +28,12 @@ public void testGeneration() { var config1 = new LanguageModelGeneratorConfig.Builder().providerId("mock1").build(); var generator1 = createGenerator(config1, languageModels); var context = new com.yahoo.language.process.Generator.Context("schema.indexing"); - var result1 = generator1.generate("hello", context, DataType.STRING); + var result1 = generator1.generate("hello", context); assertEquals("hello", result1); var config2 = new LanguageModelGeneratorConfig.Builder().providerId("mock2").build(); var generator2 = createGenerator(config2, Map.of("mock1", languageModel1, "mock2", languageModel2)); - var result2 = generator2.generate("hello", context, DataType.STRING); + var result2 = generator2.generate("hello", context); assertEquals("hello hello", result2); } From ac9357b37b40b8497b0fad2d236a90fa8e0a8161 Mon Sep 17 00:00:00 2001 From: lesters Date: Thu, 14 Nov 2024 13:50:11 +0100 Subject: [PATCH 07/12] Fix tests --- .../docprocs/indexing/ScriptManager.java | 2 +- .../ccc/indexinglanguage/IndexingParser.ccc | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java 
b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java index 8f2ec9c288e5..0110a578b30c 100644 --- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java +++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java @@ -70,7 +70,7 @@ private static Map> createScriptsMap(Docume Linguistics linguistics, Map embedders) { Map> documentFieldScripts = new HashMap<>(config.ilscript().size()); - ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders, null); + ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders, Collections.emptyMap()); parserContext.getAnnotatorConfig().setMaxTermOccurrences(config.maxtermoccurrences()); parserContext.getAnnotatorConfig().setMaxTokenizeLength(config.fieldmatchmaxlength()); diff --git a/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc b/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc index f818bf20a6b8..ab647242fc17 100644 --- a/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc +++ b/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc @@ -190,6 +190,7 @@ TOKEN : | | | + | | | | @@ -347,6 +348,7 @@ Expression value() : val = exactExp() | val = flattenExp() | val = forEachExp() | + val = generateExp() | val = getFieldExp() | val = getVarExp() | val = guardExp() | @@ -485,6 +487,23 @@ Expression forEachExp() : { return new ForEachExpression(val); } ; +Expression generateExp() : +{ + String generatorId = ""; + String generatorArgument; + List generatorArguments = new ArrayList(); +} + + ( + [ SCAN((identifierStr)+) => (generatorId = identifierStr()) ] + ( SCAN((identifierStr)+) => (generatorArgument = identifierStr()) { generatorArguments.add(generatorArgument); } )* + ) + { + // return new GenerateExpression(linguistics, generators, 
generatorId, generatorArguments); + return null; + } +; + Expression getFieldExp() : { String val; @@ -876,6 +895,7 @@ String identifierStr() : | | | + | | | | From 8d913887a1cc88ce6c1bb2a9bb0e5fe7557ab15a Mon Sep 17 00:00:00 2001 From: lesters Date: Thu, 14 Nov 2024 14:08:42 +0100 Subject: [PATCH 08/12] Set input and output types for generate expression --- .../expressions/GenerateExpression.java | 31 ++++++------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java index 34c7c22399a5..63c57038097f 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java @@ -11,6 +11,11 @@ import java.util.List; import java.util.Map; +/** + * Generates a value using the configured Generator component + * + * @author glebashnik + */ public class GenerateExpression extends Expression { private final Linguistics linguistics; private final Generator generator; @@ -29,7 +34,7 @@ public GenerateExpression( String generatorId, List generatorArguments ) { - super(null); + super(DataType.STRING); this.linguistics = linguistics; this.generatorId = generatorId; this.generatorArguments = List.copyOf(generatorArguments); @@ -56,25 +61,13 @@ else if ( ! generators.containsKey(generatorId)) { } @Override - public DataType setInputType(DataType type, VerificationContext context) { - // TODO: Not sure if this implementation of the methods is correct, needs careful review. 
- super.setInputType(type, context); - - if (type == DataType.STRING) - throw new IllegalArgumentException("generate requires a string input type, but got " + type); - - return DataType.STRING; + public DataType setInputType(DataType inputType, VerificationContext context) { + return super.setInputType(inputType, DataType.STRING, context); } @Override - public DataType setOutputType(DataType type, VerificationContext context) { - // TODO: Not sure if this implementation of the methods is correct, needs careful review. - super.setOutputType(type, type, context); - - if (type != DataType.STRING) - throw new IllegalArgumentException("generate requires a string input type, but got " + type); - - return DataType.STRING; + public DataType setOutputType(DataType outputType, VerificationContext context) { + return super.setOutputType(DataType.STRING, outputType, context); } @Override @@ -86,10 +79,6 @@ public void setStatementOutput(DocumentType documentType, Field field) { @Override protected void doVerify(VerificationContext context) { targetType = getOutputType(context); - - if (!validTarget(targetType)) - throw new VerificationException(this, "The generate target field must be a String"); - context.setCurrentType(createdOutputType()); } From 420b943531e749806fee6a66ea22781c152eaece Mon Sep 17 00:00:00 2001 From: lesters Date: Thu, 14 Nov 2024 14:35:47 +0100 Subject: [PATCH 09/12] Update GenerateExpression after merge with master branch --- .../vespa/indexinglanguage/expressions/GenerateExpression.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java index 63c57038097f..1d5f5e759749 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java +++ 
b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java @@ -67,7 +67,7 @@ public DataType setInputType(DataType inputType, VerificationContext context) { @Override public DataType setOutputType(DataType outputType, VerificationContext context) { - return super.setOutputType(DataType.STRING, outputType, context); + return super.setOutputType(DataType.STRING, outputType, null, context); } @Override From cc0218abea626c824490113980596c63afd28b5f Mon Sep 17 00:00:00 2001 From: lesters Date: Wed, 20 Nov 2024 14:16:59 +0100 Subject: [PATCH 10/12] Wire in generators to indexing processor --- .../ApplicationContainerCluster.java | 1 + .../provider/DefaultGeneratorProvider.java | 26 +++++++++++++++++++ .../docprocs/indexing/IndexingProcessor.java | 17 +++++++----- .../docprocs/indexing/ScriptManager.java | 10 ++++--- .../indexing/IndexingProcessorTestCase.java | 1 + .../indexing/ScriptManagerTestCase.java | 9 ++++--- .../com/yahoo/language/process/Generator.java | 1 + 7 files changed, 51 insertions(+), 14 deletions(-) create mode 100644 container-core/src/main/java/com/yahoo/language/provider/DefaultGeneratorProvider.java diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java index 2c865e8ce859..dbe9b8faf03d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java @@ -123,6 +123,7 @@ public ApplicationContainerCluster(TreeConfigProducer parent, String configSu addSimpleComponent("com.yahoo.language.provider.DefaultLinguisticsProvider"); addSimpleComponent("com.yahoo.language.provider.DefaultEmbedderProvider"); + addSimpleComponent("com.yahoo.language.provider.DefaultGeneratorProvider"); 
addSimpleComponent("com.yahoo.container.jdisc.SecretStoreProvider"); addSimpleComponent("com.yahoo.container.jdisc.CertificateStoreProvider"); addSimpleComponent("com.yahoo.container.jdisc.AthenzIdentityProviderProvider"); diff --git a/container-core/src/main/java/com/yahoo/language/provider/DefaultGeneratorProvider.java b/container-core/src/main/java/com/yahoo/language/provider/DefaultGeneratorProvider.java new file mode 100644 index 000000000000..5ce8fa6a2719 --- /dev/null +++ b/container-core/src/main/java/com/yahoo/language/provider/DefaultGeneratorProvider.java @@ -0,0 +1,26 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.provider; + +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.di.componentgraph.Provider; +import com.yahoo.language.process.Generator; + +/** + * Provides the default generator implementation if no generator component has been explicitly configured + * (dependency injection will fall back to providers if no components of the requested type are found).
+ * + * @author lesters + */ +@SuppressWarnings("unused") // Injected +public class DefaultGeneratorProvider implements Provider { + + @Inject + public DefaultGeneratorProvider() { } + + @Override + public Generator get() { return Generator.throwsOnUse; } + + @Override + public void deconstruct() {} + +} diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java index dec87c3ab4ae..5cd81d728400 100644 --- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java +++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java @@ -22,7 +22,9 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.provider.DefaultEmbedderProvider; +import com.yahoo.language.provider.DefaultGeneratorProvider; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.indexinglanguage.AdapterFactory; import com.yahoo.vespa.indexinglanguage.SimpleAdapterFactory; @@ -58,9 +60,12 @@ public Expression selectExpression(DocumentType documentType, String fieldName) public IndexingProcessor(DocumentTypeManager documentTypeManager, IlscriptsConfig ilscriptsConfig, Linguistics linguistics, - ComponentRegistry embedders) { + ComponentRegistry embedders, + ComponentRegistry generators) { this.documentTypeManager = documentTypeManager; - scriptManager = new ScriptManager(this.documentTypeManager, ilscriptsConfig, linguistics, toMap(embedders)); + Map embedderMap = toMap(embedders, DefaultEmbedderProvider.class); + Map generatorMap = toMap(generators, DefaultGeneratorProvider.class); + scriptManager = new ScriptManager(this.documentTypeManager, ilscriptsConfig, linguistics, embedderMap, generatorMap); adapterFactory = new SimpleAdapterFactory(new ExpressionSelector()); } @@ -132,11 +137,11 @@ private void 
processRemove(DocumentRemove input, List out) { out.add(input); } - private Map toMap(ComponentRegistry embedders) { - var map = embedders.allComponentsById().entrySet().stream() - .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)); + private Map toMap(ComponentRegistry registry, Class defaultProviderClass) { + var map = registry.allComponentsById().entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)); if (map.size() > 1) { - map.remove(DefaultEmbedderProvider.class.getName()); + map.remove(defaultProviderClass.getName()); // Ideally, this should be handled by dependency injection, however for now this workaround is necessary. } return map; diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java index c280769fe1f9..a1a913a3c063 100644 --- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java +++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java @@ -7,6 +7,7 @@ import java.util.logging.Level; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.indexinglanguage.ScriptParserContext; import com.yahoo.vespa.indexinglanguage.expressions.InputExpression; @@ -29,9 +30,9 @@ public class ScriptManager { private final DocumentTypeManager documentTypeManager; public ScriptManager(DocumentTypeManager documentTypeManager, IlscriptsConfig config, Linguistics linguistics, - Map embedders) { + Map embedders, Map generators) { this.documentTypeManager = documentTypeManager; - documentFieldScripts = createScriptsMap(documentTypeManager, config, linguistics, embedders); + documentFieldScripts = createScriptsMap(documentTypeManager, config, linguistics, embedders, generators); } private Map getScripts(DocumentType inputType) { @@ -68,9 +69,10 @@ public 
DocumentScript getScript(DocumentType inputType, String inputFieldName) { private static Map> createScriptsMap(DocumentTypeManager docTypeMgr, IlscriptsConfig config, Linguistics linguistics, - Map embedders) { + Map embedders, + Map generators) { Map> documentFieldScripts = new HashMap<>(config.ilscript().size()); - ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders, Collections.emptyMap()); + ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders, generators); parserContext.getAnnotatorConfig().setMaxTermOccurrences(config.maxtermoccurrences()); parserContext.getAnnotatorConfig().setMaxTokenizeLength(config.fieldmatchmaxlength()); diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java index df7c1a442d4a..0df3644062b2 100644 --- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java +++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTestCase.java @@ -209,6 +209,7 @@ private static IndexingProcessor newProcessor(String configId) { return new IndexingProcessor(new DocumentTypeManager(ConfigGetter.getConfig(DocumentmanagerConfig.class, configId)), ConfigGetter.getConfig(IlscriptsConfig.class, configId), new SimpleLinguistics(), + new ComponentRegistry<>(), new ComponentRegistry<>()); } } diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java index d6335aa1f4d1..f8933e043520 100644 --- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java +++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java @@ -4,6 +4,7 @@ import com.yahoo.document.DocumentType; import com.yahoo.document.DocumentTypeManager; import com.yahoo.language.process.Embedder; +import 
com.yahoo.language.process.Generator; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import org.junit.Test; @@ -27,7 +28,7 @@ public void requireThatScriptsAreAppliedToSubType() { IlscriptsConfig.Builder config = new IlscriptsConfig.Builder(); config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newssummary") .content("input title | index title")); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap()); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newsarticle"))); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @@ -41,7 +42,7 @@ public void requireThatScriptsAreAppliedToSuperType() { IlscriptsConfig.Builder config = new IlscriptsConfig.Builder(); config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newsarticle") .content("input title | index title")); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap()); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newssummary"))); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @@ -49,14 +50,14 @@ public void requireThatScriptsAreAppliedToSuperType() { @Test public void requireThatEmptyConfigurationDoesNotThrow() { var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg"); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap()); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap(), 
Generator.throwsOnUse.asMap()); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @Test public void requireThatUnknownDocumentTypeReturnsNull() { var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg"); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap()); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); for (Iterator it = typeMgr.documentTypeIterator(); it.hasNext(); ) { assertNull(scriptMgr.getScript(it.next())); } diff --git a/linguistics/src/main/java/com/yahoo/language/process/Generator.java b/linguistics/src/main/java/com/yahoo/language/process/Generator.java index 50cb89f0ad93..d2b7affda7fe 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Generator.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Generator.java @@ -124,4 +124,5 @@ public String generate(String prompt, Context context) { throw new IllegalStateException(message); } } + } From 69baaa4ad1b482bc4c32bf47d2f274840606d8cd Mon Sep 17 00:00:00 2001 From: lesters Date: Thu, 21 Nov 2024 13:31:51 +0100 Subject: [PATCH 11/12] Use Prompt instead of String in Generators --- config-model-api/abi-spec.json | 4 ++-- .../indexinglanguage/expressions/GenerateExpression.java | 3 ++- .../yahoo/vespa/indexinglanguage/GeneratorScriptTester.java | 3 ++- linguistics/abi-spec.json | 4 ++-- .../src/main/java/com/yahoo/language/process/Generator.java | 5 +++-- model-integration/abi-spec.json | 4 ++-- .../src/main/java/ai/vespa/generative/GeneratorUtils.java | 6 +++--- .../java/ai/vespa/generative/LanguageModelGenerator.java | 3 ++- .../src/main/java/ai/vespa/llm/clients/LocalLLM.java | 2 +- .../src/main/java/ai/vespa/llm/clients/OpenAI.java | 2 +- .../ai/vespa/generative/LanguageModelGeneratorTest.java | 5 +++--
vespajlib/abi-spec.json | 4 ++-- 12 files changed, 25 insertions(+), 20 deletions(-) diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json index e61c2a196ba8..2a4864c0da5e 100644 --- a/config-model-api/abi-spec.json +++ b/config-model-api/abi-spec.json @@ -1812,8 +1812,8 @@ "public final java.lang.String toString()", "public final int hashCode()", "public final boolean equals(java.lang.Object)", - "public java.lang.String name()", - "public java.lang.String id()" + "public java.lang.String id()", + "public java.lang.String name()" ], "fields" : [ ] }, diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java index 1d5f5e759749..b744039ab74f 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java @@ -1,5 +1,6 @@ package com.yahoo.vespa.indexinglanguage.expressions; +import ai.vespa.llm.completion.StringPrompt; import com.yahoo.document.DataType; import com.yahoo.document.DocumentType; import com.yahoo.document.Field; @@ -105,7 +106,7 @@ private String generateSingleValue(ExecutionContext context) { private String generate(String input, DataType targetType, ExecutionContext context) { return generator.generate( - input, + StringPrompt.from(input), new Generator.Context(destination, context.getCache()) .setLanguage(context.resolveLanguage(linguistics)) .setGeneratorId(generatorId) diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java index 7492f0f578f4..7bbc900ac4dd 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java +++ 
b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java @@ -1,5 +1,6 @@ package com.yahoo.vespa.indexinglanguage; +import ai.vespa.llm.completion.Prompt; import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.expressions.Expression; @@ -80,7 +81,7 @@ public RepeatMockGenerator(String expectedDestination, int repetitions) { this.repetitions = repetitions; } - public String generate(String prompt, Context context) { + public String generate(Prompt prompt, Context context) { var stringBuilder = new StringBuilder(); for (int i = 0; i < repetitions; i++) { diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index 72975c62bab7..c37b52763c1c 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -436,7 +436,7 @@ "methods" : [ "public void ()", "public void (java.lang.String)", - "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context)" + "public java.lang.String generate(ai.vespa.llm.completion.Prompt, com.yahoo.language.process.Generator$Context)" ], "fields" : [ ] }, @@ -451,7 +451,7 @@ "methods" : [ "public java.util.Map asMap()", "public java.util.Map asMap(java.lang.String)", - "public abstract java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context)" + "public abstract java.lang.String generate(ai.vespa.llm.completion.Prompt, com.yahoo.language.process.Generator$Context)" ], "fields" : [ "public static final java.lang.String defaultGeneratorId", diff --git a/linguistics/src/main/java/com/yahoo/language/process/Generator.java b/linguistics/src/main/java/com/yahoo/language/process/Generator.java index d2b7affda7fe..e551af7937a9 100644 --- a/linguistics/src/main/java/com/yahoo/language/process/Generator.java +++ b/linguistics/src/main/java/com/yahoo/language/process/Generator.java @@ -1,5 +1,6 @@ package com.yahoo.language.process; +import 
ai.vespa.llm.completion.Prompt; import com.yahoo.collections.LazyMap; import com.yahoo.language.Language; @@ -25,7 +26,7 @@ default Map asMap(String name) { return Map.of(name, this); } - String generate(String prompt, Context context); + String generate(Prompt prompt, Context context); class Context { @@ -120,7 +121,7 @@ public FailingGenerator(String message) { this.message = message; } - public String generate(String prompt, Context context) { + public String generate(Prompt prompt, Context context) { throw new IllegalStateException(message); } } diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json index 2614acb60320..56655e362477 100644 --- a/model-integration/abi-spec.json +++ b/model-integration/abi-spec.json @@ -168,7 +168,7 @@ "public void deconstruct()", "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)", - "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context)" + "public java.lang.String generate(ai.vespa.llm.completion.Prompt, com.yahoo.language.process.Generator$Context)" ], "fields" : [ ] }, @@ -184,7 +184,7 @@ "public void (ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)", - "public java.lang.String generate(java.lang.String, com.yahoo.language.process.Generator$Context)" + "public java.lang.String generate(ai.vespa.llm.completion.Prompt, com.yahoo.language.process.Generator$Context)" ], "fields" : [ ] }, diff --git a/model-integration/src/main/java/ai/vespa/generative/GeneratorUtils.java 
b/model-integration/src/main/java/ai/vespa/generative/GeneratorUtils.java index 33e22b2aaf55..15884415b388 100644 --- a/model-integration/src/main/java/ai/vespa/generative/GeneratorUtils.java +++ b/model-integration/src/main/java/ai/vespa/generative/GeneratorUtils.java @@ -2,6 +2,7 @@ import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Prompt; import ai.vespa.llm.completion.StringPrompt; import com.yahoo.component.ComponentId; import com.yahoo.component.provider.ComponentRegistry; @@ -47,11 +48,10 @@ public static LanguageModel findLanguageModel(String providerId, ComponentRegist } public static String generate( - String prompt, LanguageModel languageModel) + Prompt prompt, LanguageModel languageModel) { var options = new InferenceParameters(s -> ""); - var promptObj = StringPrompt.from(prompt); - var completions = languageModel.complete(promptObj, options); + var completions = languageModel.complete(prompt, options); var firstCompletion = completions.get(0); return firstCompletion.text(); } diff --git a/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java b/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java index b2ba011b4f13..514c1d8d0179 100644 --- a/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java +++ b/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java @@ -1,6 +1,7 @@ package ai.vespa.generative; import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Prompt; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; import com.yahoo.component.provider.ComponentRegistry; @@ -23,7 +24,7 @@ public LanguageModelGenerator(LanguageModelGeneratorConfig config, ComponentRegi } @Override - public String generate(String prompt, Context context) { + public String generate(Prompt prompt, Context context) { return GeneratorUtils.generate(prompt, 
languageModel); } } diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java index c77ad25a00ee..c4d1bd522b2c 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java @@ -156,7 +156,7 @@ private String rejectedExecutionReason(String prepend) { @Override - public String generate(String prompt, Context context) { + public String generate(Prompt prompt, Context context) { return GeneratorUtils.generate(prompt, this); } } diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java index a9047ec3f268..3e0965e3c6d2 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -49,7 +49,7 @@ public CompletableFuture completeAsync(Prompt prompt, } @Override - public String generate(String prompt, Context context) { + public String generate(Prompt prompt, Context context) { return GeneratorUtils.generate(prompt, this); } } diff --git a/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java b/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java index 8a412394d5d7..f25cac66989b 100644 --- a/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java +++ b/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java @@ -4,6 +4,7 @@ import ai.vespa.llm.LanguageModel; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; import com.yahoo.component.ComponentId; import com.yahoo.component.provider.ComponentRegistry; @@ -28,12 +29,12 @@ public void testGeneration() { var config1 = new 
LanguageModelGeneratorConfig.Builder().providerId("mock1").build(); var generator1 = createGenerator(config1, languageModels); var context = new com.yahoo.language.process.Generator.Context("schema.indexing"); - var result1 = generator1.generate("hello", context); + var result1 = generator1.generate(StringPrompt.from("hello"), context); assertEquals("hello", result1); var config2 = new LanguageModelGeneratorConfig.Builder().providerId("mock2").build(); var generator2 = createGenerator(config2, Map.of("mock1", languageModel1, "mock2", languageModel2)); - var result2 = generator2.generate("hello", context); + var result2 = generator2.generate(StringPrompt.from("hello"), context); assertEquals("hello hello", result2); } diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index a8b21461ef95..d7e062ff1d8d 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -3827,14 +3827,14 @@ "public static java.lang.String toMessageString(java.lang.Throwable)", "public static java.util.Optional findCause(java.lang.Throwable, java.lang.Class)", "public static void uncheck(com.yahoo.yolean.Exceptions$RunnableThrowingIOException)", - "public static void uncheckInterrupted(com.yahoo.yolean.Exceptions$RunnableThrowingInterruptedException)", - "public static void uncheckInterruptedAndRestoreFlag(com.yahoo.yolean.Exceptions$RunnableThrowingInterruptedException)", "public static varargs void uncheck(com.yahoo.yolean.Exceptions$RunnableThrowingIOException, java.lang.String, java.lang.String[])", "public static void uncheckAndIgnore(com.yahoo.yolean.Exceptions$RunnableThrowingIOException, java.lang.Class)", "public static java.util.function.Function uncheck(com.yahoo.yolean.Exceptions$FunctionThrowingIOException)", "public static java.lang.Object uncheck(com.yahoo.yolean.Exceptions$SupplierThrowingIOException)", "public static varargs java.lang.Object uncheck(com.yahoo.yolean.Exceptions$SupplierThrowingIOException, java.lang.String, java.lang.String[])", "public 
static java.lang.Object uncheckAndIgnore(com.yahoo.yolean.Exceptions$SupplierThrowingIOException, java.lang.Class)", + "public static void uncheckInterrupted(com.yahoo.yolean.Exceptions$RunnableThrowingInterruptedException)", + "public static void uncheckInterruptedAndRestoreFlag(com.yahoo.yolean.Exceptions$RunnableThrowingInterruptedException)", "public static java.lang.Object uncheckInterrupted(com.yahoo.yolean.Exceptions$SupplierThrowingInterruptedException)", "public static java.lang.RuntimeException throwUnchecked(java.lang.Throwable)" ], From f4869ecc49732e021a8b0a59d3566ed8559a0b34 Mon Sep 17 00:00:00 2001 From: lesters Date: Fri, 29 Nov 2024 11:08:43 +0100 Subject: [PATCH 12/12] ConfigurableLanguageModel implements Generator --- model-integration/abi-spec.json | 13 ++++++------- .../llm/clients/ConfigurableLanguageModel.java | 14 +++++++++++++- .../main/java/ai/vespa/llm/clients/LocalLLM.java | 1 - .../src/main/java/ai/vespa/llm/clients/OpenAI.java | 7 +------ quickbuild.sh | 2 ++ 5 files changed, 22 insertions(+), 15 deletions(-) create mode 100755 quickbuild.sh diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json index e9ee91aa7a4f..06d634f0a464 100644 --- a/model-integration/abi-spec.json +++ b/model-integration/abi-spec.json @@ -2,7 +2,8 @@ "ai.vespa.llm.clients.ConfigurableLanguageModel" : { "superClass" : "java.lang.Object", "interfaces" : [ - "ai.vespa.llm.LanguageModel" + "ai.vespa.llm.LanguageModel", + "com.yahoo.language.process.Generator" ], "attributes" : [ "public", @@ -14,7 +15,8 @@ "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)", "protected void setApiKey(ai.vespa.llm.InferenceParameters)", "protected java.lang.String getEndpoint()", - "protected void setEndpoint(ai.vespa.llm.InferenceParameters)" + "protected void setEndpoint(ai.vespa.llm.InferenceParameters)", + "public java.lang.String generate(ai.vespa.llm.completion.Prompt, com.yahoo.language.process.Generator$Context)" ], "fields" : [ ] 
}, @@ -174,17 +176,14 @@ }, "ai.vespa.llm.clients.OpenAI" : { "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel", - "interfaces" : [ - "com.yahoo.language.process.Generator" - ], + "interfaces" : [ ], "attributes" : [ "public" ], "methods" : [ "public void (ai.vespa.llm.clients.LlmClientConfig, ai.vespa.secret.Secrets)", "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)", - "public java.lang.String generate(ai.vespa.llm.completion.Prompt, com.yahoo.language.process.Generator$Context)" + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" ], "fields" : [ ] }, diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java index 015b9195a258..493ca59ad3db 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java @@ -8,6 +8,7 @@ import ai.vespa.secret.Secrets; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; +import com.yahoo.language.process.Generator; import java.util.HashMap; import java.util.logging.Logger; @@ -19,7 +20,7 @@ * @author lesters */ @Beta -public abstract class ConfigurableLanguageModel implements LanguageModel { +public abstract class ConfigurableLanguageModel implements LanguageModel, Generator { private static final Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); @@ -77,4 +78,15 @@ protected void setEndpoint(InferenceParameters params) { } } + @Override + public String generate(Prompt prompt, Context context) { + var params 
= new HashMap(); + var options = new InferenceParameters(params::get); + setApiKey(options); + + var completions = complete(prompt, options); + var firstCompletion = completions.get(0); + return firstCompletion.text(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java index c4d1bd522b2c..061f3a605d9e 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java @@ -9,7 +9,6 @@ import ai.vespa.llm.completion.Prompt; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; -import com.yahoo.document.DataType; import com.yahoo.language.process.Generator; import de.kherud.llama.LlamaModel; import de.kherud.llama.ModelParameters; diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java index 57615180350b..a9b2f419e945 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -21,7 +21,7 @@ * @author lesters */ @Beta -public class OpenAI extends ConfigurableLanguageModel implements Generator { +public class OpenAI extends ConfigurableLanguageModel { private final OpenAiClient client; @@ -47,10 +47,5 @@ public CompletableFuture completeAsync(Prompt prompt, return client.completeAsync(prompt, parameters, consumer); } - @Override - public String generate(Prompt prompt, Context context) { - return GeneratorUtils.generate(prompt, this); - } - } diff --git a/quickbuild.sh b/quickbuild.sh new file mode 100755 index 000000000000..810f5e9a8798 --- /dev/null +++ b/quickbuild.sh @@ -0,0 +1,2 @@ +mvn clean install --threads 1C -Dmaven.javadoc.skip=true -Dmaven.source.skip=true -DskipTests "$@" +