diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json index e61c2a196ba8..2a4864c0da5e 100644 --- a/config-model-api/abi-spec.json +++ b/config-model-api/abi-spec.json @@ -1812,8 +1812,8 @@ "public final java.lang.String toString()", "public final int hashCode()", "public final boolean equals(java.lang.Object)", - "public java.lang.String name()", - "public java.lang.String id()" + "public java.lang.String id()", + "public java.lang.String name()" ], "fields" : [ ] }, diff --git a/config-model/src/main/java/com/yahoo/schema/document/SDField.java b/config-model/src/main/java/com/yahoo/schema/document/SDField.java index 2483fa476676..12a00652eaec 100644 --- a/config-model/src/main/java/com/yahoo/schema/document/SDField.java +++ b/config-model/src/main/java/com/yahoo/schema/document/SDField.java @@ -13,6 +13,7 @@ import com.yahoo.documentmodel.TemporaryUnknownType; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.schema.Index; import com.yahoo.schema.Schema; @@ -399,12 +400,13 @@ public boolean hasSingleAttribute() { /** Parse an indexing expression which will use the simple linguistics implementation suitable for testing */ public void parseIndexingScript(String schemaName, String script) { - parseIndexingScript(schemaName, script, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + parseIndexingScript(schemaName, script, new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); } - public void parseIndexingScript(String schemaName, String script, Linguistics linguistics, Map embedders) { + public void parseIndexingScript(String schemaName, String script, Linguistics linguistics, + Map embedders, Map generators) { try { - ScriptParserContext config = new ScriptParserContext(linguistics, embedders); + ScriptParserContext config = new ScriptParserContext(linguistics, embedders, generators); config.setInputStream(new IndexingInput(script)); setIndexingScript(schemaName, ScriptExpression.newInstance(config)); } catch (ParseException e) { diff --git a/config-model/src/main/java/com/yahoo/schema/fieldoperation/IndexingOperation.java b/config-model/src/main/java/com/yahoo/schema/fieldoperation/IndexingOperation.java index 11065f040ea5..b5d8f4ed369b 100644 --- a/config-model/src/main/java/com/yahoo/schema/fieldoperation/IndexingOperation.java +++ b/config-model/src/main/java/com/yahoo/schema/fieldoperation/IndexingOperation.java @@ -3,6 +3,7 @@ import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.schema.document.SDField; import com.yahoo.schema.parser.ParseException; @@ -34,13 +35,14 @@ public void apply(String schemaName, SDField field) { /** Creates an indexing operation which will use the simple linguistics implementation suitable for testing */ public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine) throws ParseException { - return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + return fromStream(input, multiLine, new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), + Generator.throwsOnUse.asMap()); } - public static IndexingOperation fromStream(SimpleCharStream input, boolean multiLine, - Linguistics linguistics, Map embedders) - throws ParseException { - ScriptParserContext config = new ScriptParserContext(linguistics, embedders); + public static IndexingOperation fromStream( + SimpleCharStream input, boolean multiLine, Linguistics linguistics, Map embedders, + Map generators) throws ParseException { + ScriptParserContext config = new ScriptParserContext(linguistics, embedders, generators); config.setAnnotatorConfig(new AnnotatorConfig()); config.setInputStream(input); ScriptExpression exp; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java index 187946099390..694e3c9a3f08 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java @@ -127,6 +127,7 @@ public ApplicationContainerCluster(TreeConfigProducer parent, String configSu addSimpleComponent("com.yahoo.language.provider.DefaultLinguisticsProvider"); addSimpleComponent("com.yahoo.language.provider.DefaultEmbedderProvider"); + addSimpleComponent("com.yahoo.language.provider.DefaultGeneratorProvider"); addSimpleComponent("com.yahoo.container.jdisc.SecretStoreProvider"); addSimpleComponent("com.yahoo.container.jdisc.CertificateStoreProvider"); addSimpleComponent("com.yahoo.container.jdisc.AthenzIdentityProviderProvider"); diff --git a/config-model/src/main/javacc/SchemaParser.jj b/config-model/src/main/javacc/SchemaParser.jj index c9eff88764f7..0e2275481364 100644 --- a/config-model/src/main/javacc/SchemaParser.jj +++ b/config-model/src/main/javacc/SchemaParser.jj @@ -17,6 +17,7 @@ import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.api.ModelContext; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.search.query.ranking.Diversity; import com.yahoo.schema.DistributableResource; @@ -82,7 +83,7 @@ public class SchemaParser { */ @SuppressWarnings("deprecation") private IndexingOperation newIndexingOperation(boolean multiline) throws ParseException { - return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + return newIndexingOperation(multiline, new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); } /** @@ -91,13 +92,15 @@ public class SchemaParser { * @param multiline Whether or not to allow multi-line expressions. * @param linguistics What to use for tokenizing. */ - private IndexingOperation newIndexingOperation(boolean multiline, Linguistics linguistics, Map embedders) throws ParseException { + private IndexingOperation newIndexingOperation( + boolean multiline, Linguistics linguistics, Map embedders, + Map generators) throws ParseException { SimpleCharStream input = (SimpleCharStream)token_source.input_stream; if (token.next != null) { input.backup(token.next.image.length()); } try { - return IndexingOperation.fromStream(input, multiline, linguistics, embedders); + return IndexingOperation.fromStream(input, multiline, linguistics, embedders, generators); } finally { token.next = null; jj_ntk = -1; diff --git a/container-core/src/main/java/com/yahoo/language/provider/DefaultGeneratorProvider.java b/container-core/src/main/java/com/yahoo/language/provider/DefaultGeneratorProvider.java new file mode 100644 index 000000000000..5ce8fa6a2719 --- /dev/null +++ b/container-core/src/main/java/com/yahoo/language/provider/DefaultGeneratorProvider.java @@ -0,0 +1,26 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.provider; + +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.di.componentgraph.Provider; +import com.yahoo.language.process.Generator; + +/** + * Provides the default generator implementation if no generator component has been explicitly configured + * (dependency injection will fall back to providers if no components of the requested type is found). + * + * @author lesters + */ +@SuppressWarnings("unused") // Injected +public class DefaultGeneratorProvider implements Provider { + + @Inject + public DefaultGeneratorProvider() { } + + @Override + public Generator get() { return Generator.throwsOnUse; } + + @Override + public void deconstruct() {} + +} diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java index dec87c3ab4ae..5cd81d728400 100644 --- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java +++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/IndexingProcessor.java @@ -22,7 +22,9 @@ import com.yahoo.io.GrowableByteBuffer; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.provider.DefaultEmbedderProvider; +import com.yahoo.language.provider.DefaultGeneratorProvider; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.indexinglanguage.AdapterFactory; import com.yahoo.vespa.indexinglanguage.SimpleAdapterFactory; @@ -58,9 +60,12 @@ public Expression selectExpression(DocumentType documentType, String fieldName) public IndexingProcessor(DocumentTypeManager documentTypeManager, IlscriptsConfig ilscriptsConfig, Linguistics linguistics, - ComponentRegistry embedders) { + ComponentRegistry embedders, + ComponentRegistry generators) { this.documentTypeManager = documentTypeManager; - scriptManager = new ScriptManager(this.documentTypeManager, ilscriptsConfig, linguistics, toMap(embedders)); + Map embedderMap = toMap(embedders, DefaultEmbedderProvider.class); + Map generatorMap = toMap(generators, DefaultGeneratorProvider.class); + scriptManager = new ScriptManager(this.documentTypeManager, ilscriptsConfig, linguistics, embedderMap, generatorMap); adapterFactory = new SimpleAdapterFactory(new ExpressionSelector()); } @@ -132,11 +137,11 @@ private void processRemove(DocumentRemove input, List out) { out.add(input); } - private Map toMap(ComponentRegistry embedders) { - var map = embedders.allComponentsById().entrySet().stream() - .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)); + private Map toMap(ComponentRegistry registry, Class defaultProviderClass) { + var map = registry.allComponentsById().entrySet().stream() + .collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue)); if (map.size() > 1) { - map.remove(DefaultEmbedderProvider.class.getName()); + map.remove(defaultProviderClass.getName()); // Ideally, this should be handled by dependency injection, however for now this workaround is necessary. } return map; diff --git a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java index 1fdf7e5fefdf..80f98010a95a 100644 --- a/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java +++ b/docprocs/src/main/java/com/yahoo/docprocs/indexing/ScriptManager.java @@ -6,6 +6,7 @@ import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import com.yahoo.vespa.indexinglanguage.ScriptParserContext; import com.yahoo.vespa.indexinglanguage.expressions.InputExpression; @@ -31,9 +32,9 @@ public class ScriptManager { private final DocumentTypeManager documentTypeManager; public ScriptManager(DocumentTypeManager documentTypeManager, IlscriptsConfig config, Linguistics linguistics, - Map embedders) { + Map embedders, Map generators) { this.documentTypeManager = documentTypeManager; - documentFieldScripts = createScriptsMap(documentTypeManager, config, linguistics, embedders); + documentFieldScripts = createScriptsMap(documentTypeManager, config, linguistics, embedders, generators); } private Map getScripts(DocumentType inputType) { @@ -70,9 +71,10 @@ public DocumentScript getScript(DocumentType inputType, String inputFieldName) { private static Map> createScriptsMap(DocumentTypeManager documentTypes, IlscriptsConfig config, Linguistics linguistics, - Map embedders) { + Map embedders, + Map generators) { Map> documentFieldScripts = new HashMap<>(config.ilscript().size()); - ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders); + ScriptParserContext parserContext = new ScriptParserContext(linguistics, embedders, generators); parserContext.getAnnotatorConfig().setMaxTermOccurrences(config.maxtermoccurrences()); parserContext.getAnnotatorConfig().setMaxTokenizeLength(config.fieldmatchmaxlength()); diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTester.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTester.java index 9a24861ebe06..eeb2ad7184fd 100644 --- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTester.java +++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/IndexingProcessorTester.java @@ -67,6 +67,7 @@ private static IndexingProcessor newProcessor(String configId) { return new IndexingProcessor(new DocumentTypeManager(ConfigGetter.getConfig(DocumentmanagerConfig.class, configId)), ConfigGetter.getConfig(IlscriptsConfig.class, configId), new SimpleLinguistics(), + new ComponentRegistry<>(), new ComponentRegistry<>()); } diff --git a/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java b/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java index d6335aa1f4d1..f8933e043520 100644 --- a/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java +++ b/docprocs/src/test/java/com/yahoo/docprocs/indexing/ScriptManagerTestCase.java @@ -4,6 +4,7 @@ import com.yahoo.document.DocumentType; import com.yahoo.document.DocumentTypeManager; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.vespa.configdefinition.IlscriptsConfig; import org.junit.Test; @@ -27,7 +28,7 @@ public void requireThatScriptsAreAppliedToSubType() { IlscriptsConfig.Builder config = new IlscriptsConfig.Builder(); config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newssummary") .content("input title | index title")); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap()); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newsarticle"))); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @@ -41,7 +42,7 @@ public void requireThatScriptsAreAppliedToSuperType() { IlscriptsConfig.Builder config = new IlscriptsConfig.Builder(); config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("newsarticle") .content("input title | index title")); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap()); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(config), null, Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); assertNotNull(scriptMgr.getScript(typeMgr.getDocumentType("newssummary"))); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @@ -49,14 +50,14 @@ public void requireThatScriptsAreAppliedToSuperType() { @Test public void requireThatEmptyConfigurationDoesNotThrow() { var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg"); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap()); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); assertNull(scriptMgr.getScript(new DocumentType("unknown"))); } @Test public void requireThatUnknownDocumentTypeReturnsNull() { var typeMgr = DocumentTypeManager.fromFile("src/test/cfg/documentmanager_inherit.cfg"); - ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap()); + ScriptManager scriptMgr = new ScriptManager(typeMgr, new IlscriptsConfig(new IlscriptsConfig.Builder()), null, Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap()); for (Iterator it = typeMgr.documentTypeIterator(); it.hasNext(); ) { assertNull(scriptMgr.getScript(it.next())); } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java index f243f854c299..db14014b4a83 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParser.java @@ -47,6 +47,8 @@ private static T parse(ScriptParserContext context, Parse parser.setDefaultFieldName(context.getDefaultFieldName()); parser.setLinguistics(context.getLinguistcs()); parser.setEmbedders(context.getEmbedders()); + parser.setGenerators(context.getGenerators()); + try { return method.call(parser); } catch (ParseException e) { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java index 01c688af8e33..438eb8f15c43 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ScriptParserContext.java @@ -3,6 +3,7 @@ import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.vespa.indexinglanguage.parser.CharStream; @@ -17,12 +18,14 @@ public class ScriptParserContext { private AnnotatorConfig annotatorConfig = new AnnotatorConfig(); private Linguistics linguistics; private final Map embedders; + private final Map generators; private String defaultFieldName = null; private CharStream inputStream = null; - public ScriptParserContext(Linguistics linguistics, Map embedders) { + public ScriptParserContext(Linguistics linguistics, Map embedders, Map generators) { this.linguistics = linguistics; this.embedders = embedders; + this.generators = generators; } public AnnotatorConfig getAnnotatorConfig() { @@ -46,6 +49,9 @@ public ScriptParserContext setLinguistics(Linguistics linguistics) { public Map getEmbedders() { return Collections.unmodifiableMap(embedders); } + + public Map getGenerators() { return Collections.unmodifiableMap(generators); + } public String getDefaultFieldName() { return defaultFieldName; diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java index c1c2b5fc68fd..ef7e2e952f7a 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/Expression.java @@ -9,6 +9,7 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.*; import com.yahoo.vespa.indexinglanguage.parser.IndexingInput; @@ -301,7 +302,11 @@ public static Expression fromString(String expression) throws ParseException { } public static Expression fromString(String expression, Linguistics linguistics, Map embedders) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); + return newInstance(new ScriptParserContext(linguistics, embedders, Map.of()).setInputStream(new IndexingInput(expression))); + } + + public static Expression fromString(String expression, Linguistics linguistics, Map embedders, Map generators) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders, generators).setInputStream(new IndexingInput(expression))); } public static Expression newInstance(ScriptParserContext context) throws ParseException { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java new file mode 100644 index 000000000000..b744039ab74f --- /dev/null +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/GenerateExpression.java @@ -0,0 +1,149 @@ +package com.yahoo.vespa.indexinglanguage.expressions; + +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.document.DataType; +import com.yahoo.document.DocumentType; +import com.yahoo.document.Field; +import com.yahoo.document.datatypes.StringFieldValue; +import com.yahoo.language.Linguistics; +import com.yahoo.language.process.Generator; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Generates a value using the configured Generator component + * + * @author glebashnik + */ +public class GenerateExpression extends Expression { + private final Linguistics linguistics; + private final Generator generator; + private final String generatorId; + private final List generatorArguments; + + /** The destination the generated value will be written to in the form [schema name].[field name] */ + private String destination; + + /** The target type we are generating into. */ + private DataType targetType; + + public GenerateExpression( + Linguistics linguistics, + Map generators, + String generatorId, + List generatorArguments + ) { + super(DataType.STRING); + this.linguistics = linguistics; + this.generatorId = generatorId; + this.generatorArguments = List.copyOf(generatorArguments); + + boolean generatorIdProvided = generatorId != null && !generatorId.isEmpty(); + + if (generators.isEmpty()) { + throw new IllegalStateException("No generators provided"); // should never happen + } + else if (generators.size() == 1 && ! generatorIdProvided) { + this.generator = generators.entrySet().stream().findFirst().get().getValue(); + } + else if (generators.size() > 1 && ! generatorIdProvided) { + this.generator = new Generator.FailingGenerator( + "Multiple generators are provided but no generator id is given. " + + "Valid generators are " + validGenerators(generators)); + } + else if ( ! generators.containsKey(generatorId)) { + this.generator = new Generator.FailingGenerator("Can't find generator '" + generatorId + "'. " + + "Valid generators are " + validGenerators(generators)); + } else { + this.generator = generators.get(generatorId); + } + } + + @Override + public DataType setInputType(DataType inputType, VerificationContext context) { + return super.setInputType(inputType, DataType.STRING, context); + } + + @Override + public DataType setOutputType(DataType outputType, VerificationContext context) { + return super.setOutputType(DataType.STRING, outputType, null, context); + } + + @Override + public void setStatementOutput(DocumentType documentType, Field field) { + targetType = field.getDataType(); + destination = documentType.getName() + "." + field.getName(); + } + + @Override + protected void doVerify(VerificationContext context) { + targetType = getOutputType(context); + context.setCurrentType(createdOutputType()); + } + + @Override + protected void doExecute(ExecutionContext context) { + if (context.getCurrentValue() == null) return; + + String output; + if (context.getCurrentValue().getDataType() == DataType.STRING) { + output = generateSingleValue(context); + } + else { + throw new IllegalArgumentException("Generate can only be done on string fields, not " + + context.getCurrentValue().getDataType()); + } + + context.setCurrentValue(new StringFieldValue(output)); + } + + private String generateSingleValue(ExecutionContext context) { + StringFieldValue input = (StringFieldValue)context.getCurrentValue(); + return generate(input.getString(), targetType, context); + } + + private String generate(String input, DataType targetType, ExecutionContext context) { + return generator.generate( + StringPrompt.from(input), + new Generator.Context(destination, context.getCache()) + .setLanguage(context.resolveLanguage(linguistics)) + .setGeneratorId(generatorId) + ); + } + + @Override + public DataType createdOutputType() { + return targetType; + } + + private boolean validTarget(DataType target) { + return target == DataType.STRING; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("generate"); + if (this.generatorId != null && !this.generatorId.isEmpty()) + sb.append(" ").append(this.generatorId); + generatorArguments.forEach(arg -> sb.append(" ").append(arg)); + return sb.toString(); + } + + @Override + public int hashCode() { return GenerateExpression.class.hashCode(); } + + @Override + public boolean equals(Object o) { + return o instanceof GenerateExpression; + } + + private static String validGenerators(Map generators) { + List generatorIds = new ArrayList<>(); + generators.forEach((key, value) -> generatorIds.add(key)); + generatorIds.sort(null); + return String.join(", ", generatorIds); + } +} diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java index 1c4e097b1f8a..fc54ce3b4a6c 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ScriptExpression.java @@ -5,6 +5,7 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.ExpressionConverter; import com.yahoo.vespa.indexinglanguage.ScriptParser; @@ -134,7 +135,13 @@ public static ScriptExpression fromString(String expression) throws ParseExcepti } public static ScriptExpression fromString(String expression, Linguistics linguistics, Map embedders) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); + return newInstance(new ScriptParserContext(linguistics, embedders, Map.of()).setInputStream(new IndexingInput(expression))); + } + + public static Expression fromString( + String expression, Linguistics linguistics, Map embedders, + Map generators) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders, generators).setInputStream(new IndexingInput(expression))); } public static ScriptExpression newInstance(ScriptParserContext config) throws ParseException { diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java index 1e16f6e52646..bb3ee5a76339 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/StatementExpression.java @@ -4,6 +4,7 @@ import com.yahoo.document.DataType; import com.yahoo.language.Linguistics; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.ExpressionConverter; import com.yahoo.vespa.indexinglanguage.ScriptParser; @@ -160,7 +161,13 @@ public static StatementExpression fromString(String expression) throws ParseExce } public static StatementExpression fromString(String expression, Linguistics linguistics, Map embedders) throws ParseException { - return newInstance(new ScriptParserContext(linguistics, embedders).setInputStream(new IndexingInput(expression))); + return newInstance(new ScriptParserContext(linguistics, embedders, Map.of()).setInputStream(new IndexingInput(expression))); + } + + public static StatementExpression fromString( + String expression, Linguistics linguistics, Map embedders, + Map generators) throws ParseException { + return newInstance(new ScriptParserContext(linguistics, embedders, generators).setInputStream(new IndexingInput(expression))); } public static StatementExpression newInstance(ScriptParserContext config) throws ParseException { diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index ad247601dbef..2025ba535fb4 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -33,6 +33,7 @@ import com.yahoo.text.StringUtilities; import com.yahoo.vespa.indexinglanguage.expressions.*; import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.Linguistics; /** @@ -43,6 +44,7 @@ public class IndexingParser { private String defaultFieldName; private Linguistics linguistics; private Map embedders; + private Map generators; private AnnotatorConfig annotatorCfg; public IndexingParser(String str) { @@ -64,6 +66,11 @@ public class IndexingParser { return this; } + public IndexingParser setGenerators(Map generators) { + this.generators = generators; + return this; + } + public IndexingParser setAnnotatorConfig(AnnotatorConfig cfg) { annotatorCfg = cfg; return this; @@ -158,6 +165,7 @@ TOKEN : | | | + | | | | @@ -306,6 +314,7 @@ Expression value() : val = clearStateExp() | val = echoExp() | val = embedExp() | + val = generateExp() | val = exactExp() | val = executionValueExp() | val = flattenExp() | @@ -412,6 +421,20 @@ Expression embedExp() : { return new EmbedExpression(linguistics, embedders, embedderId, embedderArguments); } } +Expression generateExp() : +{ + String generatorId = ""; + String generatorArgument; + List generatorArguments = new ArrayList(); +} +{ + ( + [ LOOKAHEAD(2) generatorId = identifier() ] + ( LOOKAHEAD(2) generatorArgument = identifier() { generatorArguments.add(generatorArgument); } )* + ) + { return new GenerateExpression(linguistics, generators, generatorId, generatorArguments); } +} + Expression exactExp() : { int maxTokenLength = annotatorCfg.getMaxTokenLength(); @@ -835,6 +858,7 @@ String identifier() : | | | + | | | | diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java index c4a53c1af683..e4171e60a709 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/EmbeddingScriptTester.java @@ -63,6 +63,34 @@ public void testStatement(String expressionString, String input, String targetTe } } + public void testStatement2(String expressionString, String input, String targetTensorType, String expected) { + var expression = expressionFrom(expressionString); + TensorType tensorType = TensorType.fromSpec(targetTensorType); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myText", DataType.STRING)); + var tensorField = new Field("myTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + if (input != null) + adapter.setValue("myText", new StringFieldValue(input)); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(TensorDataType.class, expression.verify(verificationContext).getClass()); + + ExecutionContext context = new ExecutionContext(adapter); + expression.execute(context); + if (input == null) { + assertFalse(adapter.values.containsKey("myTensor")); + } + else { + assertTrue(adapter.values.containsKey("myTensor")); + assertEquals(Tensor.from(tensorType, expected), + ((TensorFieldValue) adapter.values.get("myTensor")).getTensor().get()); + } + } + public void testStatementThrows(String expressionString, String input, String expectedMessage) { try { testStatement(expressionString, input, null); diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java new file mode 100644 index 000000000000..1c733d79fdd9 --- /dev/null +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTestCase.java @@ -0,0 +1,41 @@ +package com.yahoo.vespa.indexinglanguage; + +import com.yahoo.language.process.Generator; +import org.junit.Test; + +import java.util.Map; + +public class GeneratorScriptTestCase { + + @Test + public void testGenerate() { + // No generators - parsing only + var tester = new GeneratorScriptTester(Generator.throwsOnUse.asMap()); + tester.expressionFrom("input myText | generate | attribute 'myGeneratedText'"); + + // One generator + tester = new GeneratorScriptTester(Map.of( + "gen1", new GeneratorScriptTester.RepeatMockGenerator("myDocument.myGeneratedText"))); + tester.testStatement("input myText | generate | attribute myGeneratedText", + "hello", "hello hello"); + tester.testStatement("input myText | generate gen1 | attribute myGeneratedText", + "hello", "hello hello"); + tester.testStatement("input myText | generate 'gen1' | attribute 'myGeneratedText'", + "hello", "hello hello"); + tester.testStatement("input myText | generate 'gen1' | attribute myGeneratedText", + null, null); + + // Two generators + tester = new GeneratorScriptTester(Map.of( + "gen1", new GeneratorScriptTester.RepeatMockGenerator("myDocument.myGeneratedText", 2), + "gen2", new GeneratorScriptTester.RepeatMockGenerator("myDocument.myGeneratedText", 3))); + tester.testStatement("input myText | generate gen1 | attribute myGeneratedText", + "hello", "hello hello"); + tester.testStatement("input myText | generate gen2 | attribute myGeneratedText", + "hello", "hello hello hello"); + tester.testStatementThrows("input myText | generate | attribute myGeneratedText", + "hello", "Multiple generators are provided but no generator id is given. Valid generators are gen1, gen2"); + tester.testStatementThrows("input myText | generate gen3 | attribute myGeneratedText", + "hello", "Can't find generator 'gen3'. Valid generators are gen1, gen2"); + } +} \ No newline at end of file diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java new file mode 100644 index 000000000000..7bbc900ac4dd --- /dev/null +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/GeneratorScriptTester.java @@ -0,0 +1,99 @@ +package com.yahoo.vespa.indexinglanguage; + +import ai.vespa.llm.completion.Prompt; +import com.yahoo.language.process.Generator; +import com.yahoo.language.simple.SimpleLinguistics; +import com.yahoo.vespa.indexinglanguage.expressions.Expression; +import com.yahoo.vespa.indexinglanguage.parser.ParseException; +import com.yahoo.document.DataType; +import com.yahoo.document.DocumentType; +import com.yahoo.document.Field; +import com.yahoo.document.datatypes.StringFieldValue; +import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class GeneratorScriptTester { + private final Map generators; + + public GeneratorScriptTester(Map generators) { + this.generators = generators; + } + + public void testStatement(String expressionString, String input, String expected) { + var expression = expressionFrom(expressionString); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("myText", DataType.STRING)); + var generatedField = new Field("myGeneratedText", DataType.STRING); + adapter.createField(generatedField); + + if (input != null) + adapter.setValue("myText", new StringFieldValue(input)); + + expression.setStatementOutput(new DocumentType("myDocument"), generatedField); + + ExecutionContext context = new ExecutionContext(adapter); + expression.execute(context); + + if (input == null) { + assertFalse(adapter.values.containsKey("myGeneratedText")); + } + else { + assertTrue(adapter.values.containsKey("myGeneratedText")); + assertEquals(expected, ((StringFieldValue)adapter.values.get("myGeneratedText")).getString()); + } + } + + public void testStatementThrows(String expressionString, String input, String expectedMessage) { + try { + testStatement(expressionString, input, null); + fail(); + } catch (IllegalStateException e) { + assertEquals(expectedMessage, e.getMessage()); + } + } + + public Expression expressionFrom(String string) { + try { + return Expression.fromString(string, new SimpleLinguistics(), Map.of(), generators); + } + catch (ParseException e) { + throw new RuntimeException(e); + } + } + + public static class RepeatMockGenerator implements Generator { + final String expectedDestination; + final int repetitions; + + public RepeatMockGenerator(String expectedDestination) { + this(expectedDestination, 2); + } + + public RepeatMockGenerator(String expectedDestination, int repetitions) { + this.expectedDestination = expectedDestination; + this.repetitions = repetitions; + } + + public String generate(Prompt prompt, Context context) { + var stringBuilder = new StringBuilder(); + + for (int i = 0; i < repetitions; i++) { + stringBuilder.append(prompt); + stringBuilder.append(" "); + } + + return stringBuilder.toString().trim(); + } + + void verifyDestination(Generator.Context context) { + assertEquals(expectedDestination, context.getDestination()); + } + } +} diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java index ac95c72a64bd..9f0e617dd40b 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptParserTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.indexinglanguage; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.expressions.EchoExpression; import com.yahoo.vespa.indexinglanguage.expressions.InputExpression; @@ -96,7 +97,9 @@ private static void assertException(ParseException e, String expectedMessage) th } private static ScriptParserContext newContext(String input) { - return new ScriptParserContext(new SimpleLinguistics(), Embedder.throwsOnUse.asMap()).setInputStream(new IndexingInput(input)); + return new ScriptParserContext( + new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap() + ).setInputStream(new IndexingInput(input)); } } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java index 7a92d51fda39..e13db73f7d7d 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/DefaultFieldNameTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.indexinglanguage.parser; import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Generator; import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.ScriptParserContext; import com.yahoo.vespa.indexinglanguage.expressions.Expression; @@ -18,10 +19,9 @@ public class DefaultFieldNameTestCase { @Test public void requireThatDefaultFieldNameIsAppliedWhenArgumentIsMissing() throws ParseException { IndexingInput input = new IndexingInput("input"); - InputExpression exp = (InputExpression)Expression.newInstance(new ScriptParserContext(new SimpleLinguistics(), - Embedder.throwsOnUse.asMap()) - .setInputStream(input) - .setDefaultFieldName("foo")); + InputExpression exp = (InputExpression) Expression.newInstance(new ScriptParserContext( + new SimpleLinguistics(), Embedder.throwsOnUse.asMap(), Generator.throwsOnUse.asMap() + ).setInputStream(input).setDefaultFieldName("foo")); assertEquals("foo", exp.getFieldName()); } diff --git a/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc b/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc index 9237ff5cd143..ab647242fc17 100644 --- a/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc +++ b/integration/schema-language-server/language-server/src/main/ccc/indexinglanguage/IndexingParser.ccc @@ -33,6 +33,7 @@ INJECT IndexingParser: import com.yahoo.vespa.indexinglanguage.expressions.*; import com.yahoo.vespa.indexinglanguage.linguistics.AnnotatorConfig; import com.yahoo.language.process.Embedder; + import com.yahoo.language.process.Generator; import com.yahoo.language.Linguistics; { /** @@ -43,6 +44,7 @@ INJECT IndexingParser: private String defaultFieldName; private Linguistics linguistics; private Map embedders; + private Map generators; private AnnotatorConfig annotatorCfg; private PrintStream logger = new PrintStream( @@ -73,6 +75,11 @@ INJECT IndexingParser: this.embedders = embedders; return this; } + + public IndexingParser setGenerators(Map generators) { + this.generators = generators; + return this; + } public IndexingParser setAnnotatorConfig(AnnotatorConfig cfg) { annotatorCfg = cfg; @@ -183,6 +190,7 @@ TOKEN : | | | + | | | | @@ -340,6 +348,7 @@ Expression value() : val = exactExp() | val = flattenExp() | val = forEachExp() | + val = generateExp() | val = getFieldExp() | val = getVarExp() | val = guardExp() | @@ -478,6 +487,23 @@ Expression forEachExp() : { return new ForEachExpression(val); } ; +Expression generateExp() : +{ + String generatorId = ""; + String generatorArgument; + List generatorArguments = new ArrayList(); +} + + ( + [ SCAN((identifierStr)+) => (generatorId = identifierStr()) ] + ( SCAN((identifierStr)+) => (generatorArgument = identifierStr()) { generatorArguments.add(generatorArgument); } )* + ) + { + // return new GenerateExpression(linguistics, generators, generatorId, generatorArguments); + return null; + } +; + Expression getFieldExp() : { String val; @@ -869,6 +895,7 @@ String identifierStr() : | | | + | | | | diff --git a/linguistics/abi-spec.json b/linguistics/abi-spec.json index ba5bc3c79702..c37b52763c1c 100644 --- a/linguistics/abi-spec.json +++ b/linguistics/abi-spec.json @@ -403,6 +403,61 @@ "public static final com.yahoo.language.process.Embedder throwsOnUse" ] }, + "com.yahoo.language.process.Generator$Context" : { + "superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void (java.lang.String)", + "public void (java.lang.String, java.util.Map)", + "public com.yahoo.language.process.Generator$Context copy()", + "public com.yahoo.language.Language getLanguage()", + "public com.yahoo.language.process.Generator$Context setLanguage(com.yahoo.language.Language)", + "public java.lang.String getDestination()", + "public com.yahoo.language.process.Generator$Context setDestination(java.lang.String)", + "public java.lang.String getGeneratorId()", + "public com.yahoo.language.process.Generator$Context setGeneratorId(java.lang.String)", + "public void putCachedValue(java.lang.Object, java.lang.Object)", + "public java.lang.Object getCachedValue(java.lang.Object)", + "public java.lang.Object computeCachedValueIfAbsent(java.lang.Object, java.util.function.Supplier)" + ], + "fields" : [ ] + }, + "com.yahoo.language.process.Generator$FailingGenerator" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.language.process.Generator" + ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void ()", + "public void (java.lang.String)", + "public java.lang.String generate(ai.vespa.llm.completion.Prompt, com.yahoo.language.process.Generator$Context)" + ], + "fields" : [ ] + }, + "com.yahoo.language.process.Generator" : { + "superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public", + "interface", + "abstract" + ], + "methods" : [ + "public java.util.Map asMap()", + "public java.util.Map asMap(java.lang.String)", + "public abstract java.lang.String generate(ai.vespa.llm.completion.Prompt, com.yahoo.language.process.Generator$Context)" + ], + "fields" : [ + "public static final java.lang.String defaultGeneratorId", + "public static final com.yahoo.language.process.Generator throwsOnUse" + ] + }, "com.yahoo.language.process.GramSplitter$Gram" : { "superClass" : "java.lang.Object", "interfaces" : [ ], diff --git a/linguistics/src/main/java/com/yahoo/language/process/Generator.java b/linguistics/src/main/java/com/yahoo/language/process/Generator.java new file mode 100644 index 000000000000..e551af7937a9 --- /dev/null +++ b/linguistics/src/main/java/com/yahoo/language/process/Generator.java @@ -0,0 +1,129 @@ +package com.yahoo.language.process; + +import ai.vespa.llm.completion.Prompt; +import com.yahoo.collections.LazyMap; +import com.yahoo.language.Language; + +import java.util.Map; +import java.util.Objects; +import java.util.function.Supplier; + +public interface Generator { + + // Name of generator when none is explicitly given + String defaultGeneratorId = "default"; + + // An instance of this which throws IllegalStateException if attempted used + Generator throwsOnUse = new FailingGenerator(); + + // Returns this generator instance as a map with the default generator name + default Map asMap() { + return asMap(defaultGeneratorId); + } + + // Returns this generator instance as a map with the given name + default Map asMap(String name) { + return Map.of(name, this); + } + + String generate(Prompt prompt, Context context); + + + class Context { + private Language language = Language.UNKNOWN; + private String destination; + private String generatorId = "unknown"; + private final Map cache; + + public Context(String destination) { + this(destination, LazyMap.newHashMap()); + } + + /** + * @param destination the name of the recipient of the generated output + * @param cache a cache shared between all generate invocations for a single request + */ + public Context(String destination, Map cache) { + this.destination = destination; + this.cache = Objects.requireNonNull(cache); + } + + private Context(Context other) { + language = other.language; + destination = other.destination; + generatorId = other.generatorId; + this.cache = other.cache; + } + + public Generator.Context copy() { return new Context(this); } + + /** Returns the language of the text, or UNKNOWN (default) to use a language independent generation */ + public Language getLanguage() { return language; } + + /** Sets the language of the text, or UNKNOWN to use language independent generation */ + public Context setLanguage(Language language) { + this.language = language != null ? language : Language.UNKNOWN; + return this; + } + + /** + * Returns the name of the recipient of this tensor. + * This is either a query feature name + * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field"). + * This cannot be null. + */ + public String getDestination() { return destination; } + + /** + * Sets the name of the recipient of this tensor. + * This is either a query feature name + * ("query(feature)"), or a schema and field name concatenated by a dot ("schema.field"). + */ + public Context setDestination(String destination) { + this.destination = destination; + return this; + } + + /** Return the generator id or 'unknown' if not set */ + public String getGeneratorId() { return generatorId; } + + /** Sets the generator id */ + public Context setGeneratorId(String generatorId) { + this.generatorId = generatorId; + return this; + } + + public void putCachedValue(Object key, Object value) { + cache.put(key, value); + } + + /** Returns a cached value, or null if not present. */ + public Object getCachedValue(Object key) { + return cache.get(key); + } + + /** Returns the cached value, or computes and caches it if not present. */ + @SuppressWarnings("unchecked") + public T computeCachedValueIfAbsent(Object key, Supplier supplier) { + return (T) cache.computeIfAbsent(key, __ -> supplier.get()); + } + + } + + class FailingGenerator implements Generator { + private final String message; + + public FailingGenerator() { + this("No generator has been configured"); + } + + public FailingGenerator(String message) { + this.message = message; + } + + public String generate(Prompt prompt, Context context) { + throw new IllegalStateException(message); + } + } + +} diff --git a/model-integration/README b/model-integration/README index a58d88dc3112..ef9e367d13c4 100644 --- a/model-integration/README +++ b/model-integration/README @@ -1,4 +1,4 @@ -3rd party ML models and converters from these to ranking expresssions, provided as a separate bundle. +3rd party ML models and converters from these to ranking expressions, provided as a separate bundle. This has two purposes - Make converters (importers) available to config models while loading them in just a single instance even when diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json index b1277c8bf424..06d634f0a464 100644 --- a/model-integration/abi-spec.json +++ b/model-integration/abi-spec.json @@ -2,7 +2,8 @@ "ai.vespa.llm.clients.ConfigurableLanguageModel" : { "superClass" : "java.lang.Object", "interfaces" : [ - "ai.vespa.llm.LanguageModel" + "ai.vespa.llm.LanguageModel", + "com.yahoo.language.process.Generator" ], "attributes" : [ "public", @@ -14,7 +15,8 @@ "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)", "protected void setApiKey(ai.vespa.llm.InferenceParameters)", "protected java.lang.String getEndpoint()", - "protected void setEndpoint(ai.vespa.llm.InferenceParameters)" + "protected void setEndpoint(ai.vespa.llm.InferenceParameters)", + "public java.lang.String generate(ai.vespa.llm.completion.Prompt, com.yahoo.language.process.Generator$Context)" ], "fields" : [ ] }, @@ -157,7 +159,8 @@ "ai.vespa.llm.clients.LocalLLM" : { "superClass" : "com.yahoo.component.AbstractComponent", "interfaces" : [ - "ai.vespa.llm.LanguageModel" + "ai.vespa.llm.LanguageModel", + "com.yahoo.language.process.Generator" ], "attributes" : [ "public" @@ -166,7 +169,8 @@ "public void (ai.vespa.llm.clients.LlmLocalClientConfig)", "public void deconstruct()", "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)", + "public java.lang.String generate(ai.vespa.llm.completion.Prompt, com.yahoo.language.process.Generator$Context)" ], "fields" : [ ] }, diff --git a/model-integration/src/main/java/ai/vespa/generative/GeneratorUtils.java b/model-integration/src/main/java/ai/vespa/generative/GeneratorUtils.java new file mode 100644 index 000000000000..15884415b388 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/generative/GeneratorUtils.java @@ -0,0 +1,58 @@ +package ai.vespa.generative; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.component.ComponentId; +import com.yahoo.component.provider.ComponentRegistry; + +import java.util.logging.Logger; +import java.util.stream.Collectors; + + +/** + * Provide utilities to implement Generator interface. + * It is used by language models as well as other generators. + */ +public class GeneratorUtils { + public static LanguageModel findLanguageModel(String providerId, ComponentRegistry languageModels, Logger log) + throws IllegalArgumentException + { + if (languageModels.allComponents().isEmpty()) { + throw new IllegalArgumentException("No language models were found"); + } + + if (providerId == null || providerId.isEmpty()) { + var entry = languageModels.allComponentsById().entrySet().stream().findFirst(); + + if (entry.isEmpty()) { + throw new IllegalArgumentException("No language models were found"); // shouldn't happen given check above + } + + log.info("Language model provider was not found in config. " + + "Fallback to using first available language model: " + entry.get().getKey()); + + return entry.get().getValue(); + } + + final LanguageModel languageModel = languageModels.getComponent(providerId); + + if (languageModel == null) { + throw new IllegalArgumentException("No component with id '" + providerId + "' was found. " + + "Available LLM components are: " + languageModels.allComponentsById().keySet().stream() + .map(ComponentId::toString).collect(Collectors.joining(","))); + } + + return languageModel; + } + + public static String generate( + Prompt prompt, LanguageModel languageModel) + { + var options = new InferenceParameters(s -> ""); + var completions = languageModel.complete(prompt, options); + var firstCompletion = completions.get(0); + return firstCompletion.text(); + } +} diff --git a/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java b/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java new file mode 100644 index 000000000000..514c1d8d0179 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/generative/LanguageModelGenerator.java @@ -0,0 +1,30 @@ +package ai.vespa.generative; + +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.document.DataType; + +import java.util.logging.Logger; + +/** + * A generator that uses a language model to generate text. + * Unlike using a language model directly, this is supposed to be extended with configurable parameters, + * e.g. prompt template, postprocessors, etc. + */ +public class LanguageModelGenerator extends AbstractComponent implements com.yahoo.language.process.Generator { + private static final Logger logger = Logger.getLogger(LanguageModelGenerator.class.getName()); + private final LanguageModel languageModel; + + @Inject + public LanguageModelGenerator(LanguageModelGeneratorConfig config, ComponentRegistry languageModels) { + this.languageModel = GeneratorUtils.findLanguageModel(config.providerId(), languageModels, logger); + } + + @Override + public String generate(Prompt prompt, Context context) { + return GeneratorUtils.generate(prompt, languageModel); + } +} diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java index 015b9195a258..493ca59ad3db 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java @@ -8,6 +8,7 @@ import ai.vespa.secret.Secrets; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; +import com.yahoo.language.process.Generator; import java.util.HashMap; import java.util.logging.Logger; @@ -19,7 +20,7 @@ * @author lesters */ @Beta -public abstract class ConfigurableLanguageModel implements LanguageModel { +public abstract class ConfigurableLanguageModel implements LanguageModel, Generator { private static final Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); @@ -77,4 +78,15 @@ protected void setEndpoint(InferenceParameters params) { } } + @Override + public String generate(Prompt prompt, Context context) { + var params = new HashMap(); + var options = new InferenceParameters(params::get); + setApiKey(options); + + var completions = complete(prompt, options); + var firstCompletion = completions.get(0); + return firstCompletion.text(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java index bbb82db71391..061f3a605d9e 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.clients; +import ai.vespa.generative.GeneratorUtils; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.LanguageModel; import ai.vespa.llm.LanguageModelException; @@ -8,6 +9,7 @@ import ai.vespa.llm.completion.Prompt; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; +import com.yahoo.language.process.Generator; import de.kherud.llama.LlamaModel; import de.kherud.llama.ModelParameters; @@ -31,7 +33,7 @@ * * @author lesters */ -public class LocalLLM extends AbstractComponent implements LanguageModel { +public class LocalLLM extends AbstractComponent implements LanguageModel, Generator { private final static Logger logger = Logger.getLogger(LocalLLM.class.getName()); @@ -152,4 +154,8 @@ private String rejectedExecutionReason(String prepend) { } + @Override + public String generate(Prompt prompt, Context context) { + return GeneratorUtils.generate(prompt, this); + } } diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java index 28af52babdf2..a9b2f419e945 100644 --- a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.clients; +import ai.vespa.generative.GeneratorUtils; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.client.openai.OpenAiClient; import ai.vespa.llm.completion.Completion; @@ -8,6 +9,7 @@ import ai.vespa.secret.Secrets; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; +import com.yahoo.language.process.Generator; import java.util.List; import java.util.concurrent.CompletableFuture; diff --git a/model-integration/src/main/resources/configdefinitions/language-model-generator.def b/model-integration/src/main/resources/configdefinitions/language-model-generator.def new file mode 100644 index 000000000000..133015d76d2d --- /dev/null +++ b/model-integration/src/main/resources/configdefinitions/language-model-generator.def @@ -0,0 +1,11 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package=ai.vespa.generative + +# The external LLM provider - the id of a LanguageModel component +providerId string default="" + +# The default prompt to use if not overridden in query +prompt string default="" + +# The default prompt template file to use if not overridden in query. Above prompt has precedence if it is set. +promptTemplate path optional \ No newline at end of file diff --git a/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java b/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java new file mode 100644 index 000000000000..f25cac66989b --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/generative/LanguageModelGeneratorTest.java @@ -0,0 +1,74 @@ +package ai.vespa.generative; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.component.ComponentId; +import com.yahoo.component.provider.ComponentRegistry; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +import com.yahoo.document.DataType; +import org.junit.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +public class LanguageModelGeneratorTest { + @Test + public void testGeneration() { + LanguageModel languageModel1 = new RepeatMockLanguageModel(1); + LanguageModel languageModel2 = new RepeatMockLanguageModel(2); + var languageModels = Map.of("mock1", languageModel1, "mock2", languageModel2); + + var config1 = new LanguageModelGeneratorConfig.Builder().providerId("mock1").build(); + var generator1 = createGenerator(config1, languageModels); + var context = new com.yahoo.language.process.Generator.Context("schema.indexing"); + var result1 = generator1.generate(StringPrompt.from("hello"), context); + assertEquals("hello", result1); + + var config2 = new LanguageModelGeneratorConfig.Builder().providerId("mock2").build(); + var generator2 = createGenerator(config2, Map.of("mock1", languageModel1, "mock2", languageModel2)); + var result2 = generator2.generate(StringPrompt.from("hello"), context); + assertEquals("hello hello", result2); + } + + private static LanguageModelGenerator createGenerator(LanguageModelGeneratorConfig config, Map languageModels) { + ComponentRegistry models = new ComponentRegistry<>(); + languageModels.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new LanguageModelGenerator(config, models); + } + + public static class RepeatMockLanguageModel implements LanguageModel { + private final int repetitions; + + public RepeatMockLanguageModel(int repetitions) { + this.repetitions = repetitions; + } + + @Override + public List complete(Prompt prompt, InferenceParameters params) { + var stringBuilder = new StringBuilder(); + + for (int i = 0; i < repetitions; i++) { + stringBuilder.append(prompt.asString()); + stringBuilder.append(" "); + } + + return List.of(Completion.from(stringBuilder.toString().trim())); + } + + @Override + public CompletableFuture completeAsync(Prompt prompt, + InferenceParameters params, + Consumer consumer) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/quickbuild.sh b/quickbuild.sh new file mode 100755 index 000000000000..810f5e9a8798 --- /dev/null +++ b/quickbuild.sh @@ -0,0 +1,2 @@ +mvn clean install --threads 1C -Dmaven.javadoc.skip=true -Dmaven.source.skip=true -DskipTests "$@" +