From 03e83d37219f00b8dfdcd7269ecc07b9a794fc91 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 10 Jul 2024 12:50:33 +0200 Subject: [PATCH] Resolve dimension bindings in ranking expressions People can parametrize the names of features in ranking expression functions, and then naturally expects to be able to do the same with tensor dimension names. This adds support for that in type resolution and expression generation, which is enough for supporting this in rank profiles. --- .../java/com/yahoo/schema/Application.java | 4 ++ .../schema/MapEvaluationTypeContext.java | 16 ++--- .../java/com/yahoo/schema/RankProfile.java | 1 - .../yahoo/schema/derived/RawRankProfile.java | 2 +- .../com/yahoo/schema/RankProfileTestCase.java | 64 +++++++++++++++++++ model-evaluation/abi-spec.json | 1 + .../models/evaluation/LazyArrayContext.java | 5 ++ searchlib/abi-spec.json | 6 +- .../rankingexpression/ExpressionFunction.java | 2 +- .../evaluation/ArrayContext.java | 5 ++ .../evaluation/DoubleOnlyArrayContext.java | 5 ++ .../evaluation/MapContext.java | 5 ++ .../evaluation/MapTypeContext.java | 5 ++ .../rankingexpression/rule/ReferenceNode.java | 1 - .../rule/TensorFunctionNode.java | 9 ++- vespajlib/abi-spec.json | 10 ++- .../evaluation/MapEvaluationContext.java | 7 +- .../yahoo/tensor/evaluation/TypeContext.java | 3 + .../com/yahoo/tensor/functions/Argmax.java | 2 +- .../com/yahoo/tensor/functions/Argmin.java | 2 +- .../com/yahoo/tensor/functions/CellCast.java | 13 ++-- .../functions/CompositeTensorFunction.java | 2 +- .../com/yahoo/tensor/functions/Concat.java | 5 +- .../tensor/functions/CosineSimilarity.java | 15 +++-- .../functions/DenseSubspaceFunction.java | 1 + .../tensor/functions/EuclideanDistance.java | 9 +-- .../com/yahoo/tensor/functions/Expand.java | 22 +++++-- .../com/yahoo/tensor/functions/Generate.java | 5 ++ .../yahoo/tensor/functions/L1Normalize.java | 2 +- .../yahoo/tensor/functions/L2Normalize.java | 2 +- .../com/yahoo/tensor/functions/Matmul.java | 2 +- .../com/yahoo/tensor/functions/Reduce.java | 9 +-- .../yahoo/tensor/functions/ReduceJoin.java | 8 +-- .../com/yahoo/tensor/functions/Rename.java | 19 +++--- .../com/yahoo/tensor/functions/Softmax.java | 2 +- .../tensor/functions/ToStringContext.java | 6 ++ .../com/yahoo/tensor/functions/XwPlusB.java | 2 +- .../functions/CosineSimilarityTestCase.java | 28 ++++++++ .../functions/EuclideanDistanceTestCase.java | 1 + 39 files changed, 234 insertions(+), 74 deletions(-) diff --git a/config-model/src/main/java/com/yahoo/schema/Application.java b/config-model/src/main/java/com/yahoo/schema/Application.java index dbc21743d967..7142b0c5a2d7 100644 --- a/config-model/src/main/java/com/yahoo/schema/Application.java +++ b/config-model/src/main/java/com/yahoo/schema/Application.java @@ -30,6 +30,7 @@ public class Application { private final ApplicationPackage applicationPackage; private final Map schemas; private final DocumentModel documentModel; + private final RankProfileRegistry rankProfileRegistry; public Application(ApplicationPackage applicationPackage, List schemas, @@ -41,6 +42,7 @@ public Application(ApplicationPackage applicationPackage, Set> processorsToSkip, DeployLogger logger) { this.applicationPackage = applicationPackage; + this.rankProfileRegistry = rankProfileRegistry; Map schemaMap = new LinkedHashMap<>(); for (Schema schema : schemas) { @@ -87,6 +89,8 @@ public Application(ApplicationPackage applicationPackage, public ApplicationPackage applicationPackage() { return applicationPackage; } + public RankProfileRegistry rankProfileRegistry() { return rankProfileRegistry; } + /** Returns an unmodifiable list of the schemas of this application */ public Map schemas() { return schemas; } diff --git a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java index 2a8dd49a0c1a..f3e8c0a2f489 100644 --- a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java @@ -135,12 +135,13 @@ MapEvaluationTypeContext getParent(String forArgument, String boundTo) { () -> new IllegalArgumentException("argument "+forArgument+" is bound to "+boundTo+" but there is no parent context")); } - String resolveBinding(String argument) { - String bound = getBinding(argument); + @Override + public String resolveBinding(String name) { + String bound = getBinding(name); if (bound == null) { - return argument; + return name; } - return getParent(argument, bound).resolveBinding(bound); + return getParent(name, bound).resolveBinding(bound); } private TensorType resolveType(Reference reference) { @@ -148,7 +149,6 @@ private TensorType resolveType(Reference reference) { throw new IllegalArgumentException("Invocation loop: " + currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) + " -> " + reference); - // Bound to a function argument? Optional binding = boundIdentifier(reference); if (binding.isPresent()) { @@ -156,8 +156,7 @@ private TensorType resolveType(Reference reference) { // This is not pretty, but changing to bind expressions rather // than their string values requires deeper changes var expr = new RankingExpression(binding.get()); - var type = expr.type(getParent(reference.name(), binding.get())); - return type; + return expr.type(getParent(reference.name(), binding.get())); } catch (ParseException e) { throw new IllegalArgumentException(e); } @@ -180,8 +179,7 @@ private TensorType resolveType(Reference reference) { if (function.isPresent()) { var body = function.get().getBody(); var child = this.withBindings(bind(function.get().arguments(), reference.arguments())); - var type = body.type(child); - return type; + return body.type(child); } // A reference to an ONNX model? diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index ed1a4e98b49b..0d812699f887 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1184,7 +1184,6 @@ private RankingExpressionFunction compile(RankingExpressionFunction function, Map inlineFunctions, ExpressionTransforms expressionTransforms) { if (function == null) return null; - RankProfileTransformContext context = new RankProfileTransformContext(this, queryProfiles, featureTypes, diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java index 15e5891a3e34..d2f6f3b18e7f 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java @@ -288,9 +288,9 @@ private void deriveFunctionProperties(Map e : functions.entrySet()) { String propertyName = RankingExpression.propertyName(e.getKey()); if (! context.serializedFunctions().containsKey(propertyName)) { - String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString(); context.addFunctionSerialization(propertyName, expressionString); + e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue())); } diff --git a/config-model/src/test/java/com/yahoo/schema/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankProfileTestCase.java index 564bfd4a9904..f07ed4cf42ef 100644 --- a/config-model/src/test/java/com/yahoo/schema/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/RankProfileTestCase.java @@ -8,6 +8,7 @@ import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.document.DataType; +import com.yahoo.schema.derived.DerivedConfiguration; import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.search.query.profile.types.FieldDescription; @@ -417,6 +418,69 @@ void requireThatConfigIsDerivedForQueryFeatureTypeSettings() throws ParseExcepti assertQueryFeatureTypeSettings(registry.get(schema, "p2"), schema); } + @Test + void dimensionArgumentResolution() throws ParseException{ + RankProfileRegistry registry = new RankProfileRegistry(); + ApplicationBuilder builder = new ApplicationBuilder(registry); + builder.addSchema(""" +schema test { +document test { + field embeddings type tensor(d1[384]) { + indexing: attribute + } +} +rank-profile feature_logging { + inputs { + query(query_embedding_int8) tensor(d0[384]) + query(query_embedding) tensor(d0{}, d1[384]) + } + first-phase { + expression: fakeRankResult + } + function query_field_cosine_similarity(field_name, query_tensor, dimension) { + expression: cosine_similarity(attribute(field_name), query_tensor, dimension) + } + function query_field_cos_distances(field_name, query_tensor, dimension){ + expression: max(1 - query_field_cosine_similarity(field_name, query_tensor, dimension), 0.0) + } + function query_field_acos_distances(field_name, query_tensor, dimension) { + expression: acos(query_field_cosine_similarity(field_name, query_tensor, dimension)) + } + function query_field_closeness(field_name, query_tensor, dimension) { + expression: reduce(1/(1+query_field_acos_distances(field_name, query_tensor, dimension)), max) + } + summary-features { + query_field_closeness(embeddings, query(query_embedding), d1) + } +} +}"""); + Application application = builder.build(true); + RankProfile profile = application.rankProfileRegistry().get("test", "feature_logging"); + + // Rank profile content is unbound, as written: + assertEquals("join(reduce(join(attribute(field_name), query_tensor, f(a,b)(a * b)), sum, dimension), " + + "map(join(reduce(join(attribute(field_name), attribute(field_name), f(a,b)(a * b)), sum, dimension), " + + "reduce(join(query_tensor, query_tensor, f(a,b)(a * b)), sum, dimension), " + + "f(a,b)(a * b)), f(a)(sqrt(a))), f(a,b)(a / b))", + profile.findFunction("query_field_cosine_similarity").function().getBody().getRoot().toString()); + + // Derived rank profile content is bound: attribute(field_name) -> attribute(embeddings), dimension -> d1 + assertEquals("join(reduce(join(attribute(embeddings), query(query_embedding), f(a,b)(a * b)), sum, d1), " + + "map(join(reduce(join(attribute(embeddings), attribute(embeddings), f(a,b)(a * b)), sum, d1), " + + "reduce(join(query(query_embedding), query(query_embedding), f(a,b)(a * b)), sum, d1), " + + "f(a,b)(a * b)), f(a)(sqrt(a))), f(a,b)(a / b))", + findDerivedFunction(application, "feature_logging", "query_field_cosine_similarity")); + } + + private String findDerivedFunction(Application application, String rankProfileName, String functionName) { + var derived = new DerivedConfiguration(application.schemas().get("test"), application.rankProfileRegistry()); + for (var line : derived.getRankProfileList().getRankProfiles().get("feature_logging").configProperties()) { + if (line.getFirst().startsWith("rankingExpression(query_field_cosine_similarity@")) + return line.getSecond(); + } + return null; + } + private static QueryProfileRegistry setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index 667712d0daae..226bc5e60683 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -36,6 +36,7 @@ "public com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)", "public double getDouble(int)", "public int getIndex(java.lang.String)", + "public java.lang.String resolveBinding(java.lang.String)", "public int size()", "public java.util.Set names()", "public java.util.Set arguments()", diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 18a3f01cf92d..1a57800a9e23 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -112,6 +112,11 @@ public int getIndex(String name) { return requireIndexOf(name); } + @Override + public String resolveBinding(String argument) { + return null; + } + @Override public int size() { return indexedBindings.names().size(); diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 2d67abb0e048..7d06db8971c3 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -257,7 +257,7 @@ "public" ], "methods" : [ - "public void (com.yahoo.searchlib.rankingexpression.ExpressionFunction, java.lang.String, java.lang.String)", + "public void (java.lang.String, java.lang.String)", "public java.lang.String getName()", "public java.lang.String getExpressionString()" ], @@ -424,6 +424,7 @@ "public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)", "public final com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)", "public final double getDouble(int)", + "public java.lang.String resolveBinding(java.lang.String)", "public com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext clone()", "public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext clone()", "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", @@ -537,6 +538,7 @@ "public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)", "public final com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)", + "public java.lang.String resolveBinding(java.lang.String)", "public com.yahoo.searchlib.rankingexpression.evaluation.DoubleOnlyArrayContext clone()", "public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext clone()", "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", @@ -629,6 +631,7 @@ "public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)", "public void put(java.lang.String, com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public java.lang.String resolveBinding(java.lang.String)", "public java.util.Map bindings()", "public com.yahoo.searchlib.rankingexpression.evaluation.MapContext thawedCopy()", "public java.util.Set names()", @@ -651,6 +654,7 @@ "public void setType(com.yahoo.searchlib.rankingexpression.Reference, com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.TensorType getType(java.lang.String)", "public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)", + "public java.lang.String resolveBinding(java.lang.String)", "public java.util.Map bindings()", "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)" ], diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 840eacd9dd9c..fd173bca2267 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -216,7 +216,7 @@ public String toString() { * An instance of a serialization of this function, using a particular serialization context (by {@link * ExpressionFunction#expand}) */ - public class Instance { + public static class Instance { private final String name; private final String expressionString; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java index d32b14886cae..bb0faf2a608d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java @@ -120,6 +120,11 @@ public final double getDouble(int index) { return value; } + @Override + public String resolveBinding(String argument) { + return null; + } + /** * Creates a clone of this context suitable for evaluating against the same ranking expression * in a different thread (i.e, name name to index map, different value set. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java index f0685ea77fd9..0988c58e73d0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java @@ -93,6 +93,11 @@ public final Value get(int index) { return new DoubleValue(getDouble(index)); } + @Override + public String resolveBinding(String argument) { + return null; + } + /** * Creates a clone of this context suitable for evaluating against the same ranking expression * in a different thread (i.e, name name to index map, different value set. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 2b620a6d8f05..c8a3eab381f3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -73,6 +73,11 @@ public void put(String key, Value value) { bindings.put(key, value.freeze()); } + @Override + public String resolveBinding(String argument) { + return null; + } + /** Returns an immutable view of the bindings of this. */ public Map bindings() { if (frozen) return bindings; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java index 6c980181b47e..4a723eae578e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java @@ -32,6 +32,11 @@ public TensorType getType(Reference reference) { return featureTypes.get(reference); } + @Override + public String resolveBinding(String argument) { + return null; + } + /** Returns an unmodifiable map of the bindings in this */ public Map bindings() { return Collections.unmodifiableMap(featureTypes); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index b1585233a1e6..54703141cbdb 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -88,7 +88,6 @@ public StringBuilder toString(StringBuilder string, SerializationContext context if ( needSerialization ) { ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path); functionName = instance.getName(); - context.addFunctionSerialization(RankingExpression.propertyName(functionName), instance.getExpressionString()); for (Map.Entry argumentType : function.argumentTypes().entrySet()) context.addArgumentTypeSerialization(functionName, argumentType.getKey(), argumentType.getValue()); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 202dbebc3116..3c17c7830f2e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -59,8 +59,7 @@ private ExpressionNode toExpressionNode(TensorFunction f) { } private static ScalarFunction transform(ScalarFunction input, - Function transformer) - { + Function transformer) { if (input instanceof ExpressionScalarFunction wrapper) { ExpressionNode transformed = transformer.apply(wrapper.expression); return new ExpressionScalarFunction(transformed); @@ -411,6 +410,12 @@ public Value get(String name) { public TensorType getType(Reference name) { return delegate.getType(name); } + + @Override + public String resolveBinding(String argument) { + return delegate.resolveBinding(argument); + } + } private static Context asContext(EvaluationContext generic) { diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 06cf7d0d71a1..c3bb75b29781 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1570,7 +1570,8 @@ "public void put(java.lang.String, com.yahoo.tensor.Tensor)", "public com.yahoo.tensor.TensorType getType(java.lang.String)", "public com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", - "public com.yahoo.tensor.Tensor getTensor(java.lang.String)" + "public com.yahoo.tensor.Tensor getTensor(java.lang.String)", + "public java.lang.String resolveBinding(java.lang.String)" ], "fields" : [ ] }, @@ -1599,7 +1600,8 @@ ], "methods" : [ "public abstract com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", - "public abstract com.yahoo.tensor.TensorType getType(java.lang.String)" + "public abstract com.yahoo.tensor.TensorType getType(java.lang.String)", + "public abstract java.lang.String resolveBinding(java.lang.String)" ], "fields" : [ ] }, @@ -1685,7 +1687,7 @@ ], "methods" : [ "public void ()", - "public final com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)" ], "fields" : [ ] @@ -1809,6 +1811,7 @@ "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public final com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", "public int hashCode()" ], @@ -2920,6 +2923,7 @@ "methods" : [ "public static com.yahoo.tensor.functions.ToStringContext empty()", "public abstract java.lang.String getBinding(java.lang.String)", + "public java.lang.String resolveBinding(java.lang.String)", "public java.util.Optional typeContext()", "public abstract com.yahoo.tensor.functions.ToStringContext parent()" ], diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index 8cdb06143788..3d7705c42b0a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -23,11 +23,12 @@ public TensorType getType(String name) { } @Override - public TensorType getType(NAMETYPE name) { - return getType(name.name()); - } + public TensorType getType(NAMETYPE name) { return getType(name.name()); } @Override public Tensor getTensor(String name) { return bindings.get(name); } + @Override + public String resolveBinding(String name) { return name; } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java index d875f1ef4eb3..eddfb9df276a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -26,4 +26,7 @@ public interface TypeContext { */ TensorType getType(String name); + /** Returns the string a parameter is bound to, or the input name if none. */ + String resolveBinding(String name); + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index 88b5a385e9f3..0bd360ef15f0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -47,7 +47,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; + return "argmax(" + argument.toString(context) + Reduce.commaSeparatedNames(dimensions, context) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index ffee606e8f6b..8e1ad71d3848 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -47,7 +47,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; + return "argmin(" + argument.toString(context) + Reduce.commaSeparatedNames(dimensions, context) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index 5655bb020a4f..ff0fe95bc4ed 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -95,14 +95,11 @@ private Tensor castFromSomeFloat(Tensor tensor, TensorType type) { } static private Function selectRestrict(TensorType.Value toValueType) { - switch (toValueType) { - case BFLOAT16: - return val -> Float.intBitsToFloat(Float.floatToRawIntBits(val) & ~0xffff); - case INT8: - return val -> (float)val.byteValue(); - default: - return val -> val; - } + return switch (toValueType) { + case BFLOAT16 -> val -> Float.intBitsToFloat(Float.floatToRawIntBits(val) & ~0xffff); + case INT8 -> val -> (float) val.byteValue(); + default -> val -> val; + }; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 23d90e634884..87b0210cf603 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -17,7 +17,7 @@ public abstract class CompositeTensorFunction extends Ten /** Finds the type this produces by first converting it to a primitive function */ @Override - public final TensorType type(TypeContext context) { + public TensorType type(TypeContext context) { return toPrimitive().type(context); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 2635cbecb94b..0b128f77d120 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -62,7 +62,8 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")"; + return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + + ", " + context.resolveBinding(dimension) + ")"; } @Override @@ -70,7 +71,7 @@ public String toString(ToStringContext context) { @Override public TensorType type(TypeContext context) { - return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension); + return TypeResolver.concat(argumentA.type(context), argumentB.type(context), context.resolveBinding(dimension)); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java index d84b1bbdc163..2bdf5266ffcf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.MapEvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.Tensor; @@ -10,10 +11,12 @@ import java.util.List; import java.util.Objects; +import java.util.Optional; /** * Convenience for cosine similarity between vectors. - * cosine_similarity(a, b, mydim) == sum(a*b, mydim) / sqrt(sum(a*a, mydim) * sum(b*b, mydim)) + * cosine_similarity(a, b, mydim) == sum(a*b, mydim) / sqrt(sum(a*a, mydim) * sum(b*b, mydim)). + * * @author arnej */ public class CosineSimilarity extends TensorFunction { @@ -45,18 +48,18 @@ public TensorFunction withArguments(List> arg public TensorType type(TypeContext context) { TensorType t1 = arg1.toPrimitive().type(context); TensorType t2 = arg2.toPrimitive().type(context); - var d1 = t1.dimension(dimension); - var d2 = t2.dimension(dimension); + var resolvedDimension = context.resolveBinding(dimension); + var d1 = t1.dimension(resolvedDimension); + var d2 = t2.dimension(resolvedDimension); if (d1.isEmpty() || d2.isEmpty() || d1.get().type() != Dimension.Type.indexedBound || d2.get().type() != Dimension.Type.indexedBound || ! d1.get().size().equals(d2.get().size())) { throw new IllegalArgumentException("cosine_similarity expects both arguments to have the '" - + dimension + "' dimension with same size, but input types were " + + resolvedDimension + "' dimension with same size, but input types were " + t1 + " and " + t2); } - // Finds the type this produces by first converting it to a primitive function return toPrimitive().type(context); } @@ -83,7 +86,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "cosine_similarity(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + dimension + ")"; + return "cosine_similarity(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java index b6655a153616..3d6e44d86587 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java @@ -37,6 +37,7 @@ class MyTypeContext implements TypeContext { MyTypeContext(TensorType subspaceType) { this.subspaceType = subspaceType; } public TensorType getType(NAMETYPE name) { return getType(name.name()); } public TensorType getType(String name) { return argName.equals(name) ? subspaceType : null; } + public String resolveBinding(String name) { return name; } } TensorType outputType(TensorType subspaceType) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java index d627e0093bff..dc213db4cc0e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java @@ -45,15 +45,16 @@ public TensorFunction withArguments(List> arg public TensorType type(TypeContext context) { TensorType t1 = arg1.toPrimitive().type(context); TensorType t2 = arg2.toPrimitive().type(context); - var d1 = t1.dimension(dimension); - var d2 = t2.dimension(dimension); + String resolvedDimension = context.resolveBinding(dimension); + var d1 = t1.dimension(resolvedDimension); + var d2 = t2.dimension(resolvedDimension); if (d1.isEmpty() || d2.isEmpty() || d1.get().type() != Dimension.Type.indexedBound || d2.get().type() != Dimension.Type.indexedBound || ! d1.get().size().equals(d2.get().size())) { throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '" - + dimension + "' dimension with same size, but input types were " + + resolvedDimension + "' dimension with same size, but input types were " + t1 + " and " + t2); } // Finds the type this produces by first converting it to a primitive function @@ -79,7 +80,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "euclidean_distance(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + dimension + ")"; + return "euclidean_distance(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java index f5a33dde064e..da2295e66db8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java @@ -3,6 +3,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; import java.util.Objects; @@ -16,11 +17,11 @@ public class Expand extends CompositeTensorFunction { private final TensorFunction argument; - private final String dimensionName; + private final String dimension; public Expand(TensorFunction argument, String dimension) { this.argument = argument; - this.dimensionName = dimension; + this.dimension = dimension; } @Override @@ -30,22 +31,31 @@ public Expand(TensorFunction argument, String dimension) { public TensorFunction withArguments(List> arguments) { if (arguments.size() != 1) throw new IllegalArgumentException("Expand must have 1 argument, got " + arguments.size()); - return new Expand<>(arguments.get(0), dimensionName); + return new Expand<>(arguments.get(0), dimension); } @Override public PrimitiveTensorFunction toPrimitive() { - TensorType type = new TensorType.Builder(TensorType.Value.INT8).indexed(dimensionName, 1).build(); + return toPrimitive(dimension); + } + + @Override + public final TensorType type(TypeContext context) { + return toPrimitive(context.resolveBinding(dimension)).type(context); + } + + private PrimitiveTensorFunction toPrimitive(String dimension) { + TensorType type = new TensorType.Builder(TensorType.Value.INT8).indexed(dimension, 1).build(); Generate expansion = new Generate<>(type, ScalarFunctions.constant(1.0)); return new Join<>(expansion, argument, ScalarFunctions.multiply()); } @Override public String toString(ToStringContext context) { - return "expand(" + argument.toString(context) + ", " + dimensionName + ")"; + return "expand(" + argument.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override - public int hashCode() { return Objects.hash("expand", argument, dimensionName); } + public int hashCode() { return Objects.hash("expand", argument, dimension); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 947a39dafb20..fb6963fdbcb8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -183,6 +183,11 @@ public TensorType getType(String name) { return context.getType(name); } + @Override + public String resolveBinding(String name) { + return context.resolveBinding(name); + } + } /** A context which adds the bindings of the generate dimension names to the given context. */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index a5afeb6d2a42..b1ea52e880f4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -40,7 +40,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; + return "l1_normalize(" + argument.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index 47e341732ca9..c25871590816 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -42,7 +42,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; + return "l2_normalize(" + argument.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index fbf3b461a353..d97c85d64e14 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -46,7 +46,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; + return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 947fd6e00123..af1f20850851 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -87,19 +87,20 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; + return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparatedNames(dimensions, context) + ")"; } - static String commaSeparated(List list) { + static String commaSeparatedNames(List list, ToStringContext context) { StringBuilder b = new StringBuilder(); for (String element : list) - b.append(", ").append(element); + b.append(", ").append(context.resolveBinding(element)); return b.toString(); } @Override public TensorType type(TypeContext context) { - return outputType(argument.type(context), dimensions); + List resolvedDimensions = dimensions.stream().map(d -> context.resolveBinding(d)).toList(); + return outputType(argument.type(context), resolvedDimensions); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index 2d5a05187471..e6fa448fef3f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -317,10 +317,10 @@ private boolean reduceDimensionIsInnermost(Tensor a, Tensor b) { @Override public String toString(ToStringContext context) { return "reduce_join(" + argumentA.toString(context) + ", " + - argumentB.toString(context) + ", " + - combinator + ", " + - aggregator + - Reduce.commaSeparated(dimensions) + ")"; + argumentB.toString(context) + ", " + + combinator + ", " + + aggregator + + Reduce.commaSeparatedNames(dimensions, context) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 05db61f53956..eabf2e88739e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -14,6 +14,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; /** * The rename tensor function returns a tensor where some dimensions are assigned new names. @@ -71,18 +72,16 @@ public TensorFunction withArguments(List> arg @Override public TensorType type(TypeContext context) { - return type(argument.type(context)); - } - - private TensorType type(TensorType type) { - return TypeResolver.rename(type, fromDimensions, toDimensions); + List resolvedFromDimensions = fromDimensions.stream().map(d -> context.resolveBinding(d)).toList(); + List resolvedToDimensions = toDimensions.stream().map(d -> context.resolveBinding(d)).toList(); + return TypeResolver.rename(argument.type(context), resolvedFromDimensions, resolvedToDimensions); } @Override public Tensor evaluate(EvaluationContext context) { Tensor tensor = argument.evaluate(context); - TensorType renamedType = type(tensor.type()); + TensorType renamedType = TypeResolver.rename(tensor.type(), fromDimensions, toDimensions); // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; @@ -118,12 +117,12 @@ private boolean simpleRenameIsPossible(int[] toIndexes) { return true; } - private String toVectorString(List elements) { + private String toVectorString(List elements, ToStringContext context) { if (elements.size() == 1) - return elements.get(0); + return context.resolveBinding(elements.get(0)); StringBuilder b = new StringBuilder("("); for (String element : elements) - b.append(element).append(", "); + b.append(context.resolveBinding(element)).append(", "); b.setLength(b.length() - 2); b.append(")"); return b.toString(); @@ -132,7 +131,7 @@ private String toVectorString(List elements) { @Override public String toString(ToStringContext context) { return "rename(" + argument.toString(context) + ", " + - toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; + toVectorString(fromDimensions, context) + ", " + toVectorString(toDimensions, context) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index 150bf82f0e89..a0ef87d6e0b7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -46,7 +46,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "softmax(" + argument.toString(context) + ", " + dimension + ")"; + return "softmax(" + argument.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java index eac012a450b2..1faf7e051c35 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java @@ -18,6 +18,12 @@ public interface ToStringContext { /** Returns the name an identifier is bound to, or null if not bound in this context */ String getBinding(String name); + /** Returns the name an identifier is bound to, or the input name if none */ + default String resolveBinding(String name) { + String binding = getBinding(name); + return binding == null ? name : binding; + } + /** * Returns the context used to resolve types in this, if present. * In some functions serialization depends on type information. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 3913a16f35a4..d33d2e678fc0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -48,7 +48,7 @@ public String toString(ToStringContext context) { return "xw_plus_b(" + x.toString(context) + ", " + w.toString(context) + ", " + b.toString(context) + ", " + - dimension + ")"; + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java index 8bacc24c3212..5244cf358cfb 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java @@ -12,6 +12,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -51,8 +52,10 @@ public void testSimilarityInMixed() { static class MyContext implements TypeContext { Map map = new HashMap<>(); + Map bindings = new HashMap<>(); public TensorType getType(Name name) { return getType(name.name()); } public TensorType getType(String name) { return map.get(name); } + public String resolveBinding(String name) { return Optional.ofNullable(bindings.get(name)).orElse(name); } } @Test @@ -80,4 +83,29 @@ public void testExpansion() { assertEquals("tensor(foo{},z[4])", resType.toString()); } + @Test + public void testExpansionWithDimensionBinding() { + var tTypeA = TensorType.fromSpec("tensor(foo{},vecdim[128])"); + var tTypeB = TensorType.fromSpec("tensor(vecdim[128],z[4])"); + var a = new VariableTensor<>("left", tTypeA); + var b = new VariableTensor<>("right", tTypeB); + var op = new CosineSimilarity<>(a, b, "dimensionArgument"); + assertEquals("join(" + + ( "reduce(join(left, right, f(a,b)(a * b)), sum, dimensionArgument), " + + "map(" + + ( "join(" + + ( "reduce(join(left, left, f(a,b)(a * b)), sum, dimensionArgument), " + + "reduce(join(right, right, f(a,b)(a * b)), sum, dimensionArgument), " + + "f(a,b)(a * b)), " ) + + "f(a)(sqrt(a))), " ) + + "f(a,b)(a / b)" ) + + ")", + op.toPrimitive().toString()); + var context = new MyContext(); + context.map.put("left", tTypeA); + context.map.put("right", tTypeB); + context.bindings.put("dimensionArgument", "vecdim"); + var resType = op.type(context); + assertEquals("tensor(foo{},z[4])", resType.toString()); + } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java index 42f9ef33ff1b..f7554a1b6b3e 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java @@ -50,6 +50,7 @@ static class MyContext implements TypeContext { Map map = new HashMap<>(); public TensorType getType(Name name) { return getType(name.name()); } public TensorType getType(String name) { return map.get(name); } + public String resolveBinding(String name) { return name; } } @Test