Skip to content

Commit

Permalink
Merge pull request #31926 from vespa-engine/bratseth/resolve-dimensio…
Browse files Browse the repository at this point in the history
…n-bindings

Resolve dimension bindings in ranking expressions
  • Loading branch information
bratseth authored Jul 10, 2024
2 parents 8c2d413 + 03e83d3 commit b6a2fcb
Show file tree
Hide file tree
Showing 39 changed files with 234 additions and 74 deletions.
4 changes: 4 additions & 0 deletions config-model/src/main/java/com/yahoo/schema/Application.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class Application {
private final ApplicationPackage applicationPackage;
private final Map<String, Schema> schemas;
private final DocumentModel documentModel;
private final RankProfileRegistry rankProfileRegistry;

public Application(ApplicationPackage applicationPackage,
List<Schema> schemas,
Expand All @@ -41,6 +42,7 @@ public Application(ApplicationPackage applicationPackage,
Set<Class<? extends Processor>> processorsToSkip,
DeployLogger logger) {
this.applicationPackage = applicationPackage;
this.rankProfileRegistry = rankProfileRegistry;

Map<String, Schema> schemaMap = new LinkedHashMap<>();
for (Schema schema : schemas) {
Expand Down Expand Up @@ -87,6 +89,8 @@ public Application(ApplicationPackage applicationPackage,

public ApplicationPackage applicationPackage() { return applicationPackage; }

public RankProfileRegistry rankProfileRegistry() { return rankProfileRegistry; }

/** Returns an unmodifiable list of the schemas of this application */
public Map<String, Schema> schemas() { return schemas; }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,29 +135,28 @@ MapEvaluationTypeContext getParent(String forArgument, String boundTo) {
() -> new IllegalArgumentException("argument "+forArgument+" is bound to "+boundTo+" but there is no parent context"));
}

String resolveBinding(String argument) {
String bound = getBinding(argument);
@Override
public String resolveBinding(String name) {
String bound = getBinding(name);
if (bound == null) {
return argument;
return name;
}
return getParent(argument, bound).resolveBinding(bound);
return getParent(name, bound).resolveBinding(bound);
}

private TensorType resolveType(Reference reference) {
if (currentResolutionCallStack.contains(reference))
throw new IllegalArgumentException("Invocation loop: " +
currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) +
" -> " + reference);

// Bound to a function argument?
Optional<String> binding = boundIdentifier(reference);
if (binding.isPresent()) {
try {
// This is not pretty, but changing to bind expressions rather
// than their string values requires deeper changes
var expr = new RankingExpression(binding.get());
var type = expr.type(getParent(reference.name(), binding.get()));
return type;
return expr.type(getParent(reference.name(), binding.get()));
} catch (ParseException e) {
throw new IllegalArgumentException(e);
}
Expand All @@ -180,8 +179,7 @@ private TensorType resolveType(Reference reference) {
if (function.isPresent()) {
var body = function.get().getBody();
var child = this.withBindings(bind(function.get().arguments(), reference.arguments()));
var type = body.type(child);
return type;
return body.type(child);
}

// A reference to an ONNX model?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,6 @@ private RankingExpressionFunction compile(RankingExpressionFunction function,
Map<String, RankingExpressionFunction> inlineFunctions,
ExpressionTransforms expressionTransforms) {
if (function == null) return null;

RankProfileTransformContext context = new RankProfileTransformContext(this,
queryProfiles,
featureTypes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,9 @@ private void deriveFunctionProperties(Map<String, RankProfile.RankingExpressionF
for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) {
String propertyName = RankingExpression.propertyName(e.getKey());
if (! context.serializedFunctions().containsKey(propertyName)) {

String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString();
context.addFunctionSerialization(propertyName, expressionString);

e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey())
.forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.document.DataType;
import com.yahoo.schema.derived.DerivedConfiguration;
import com.yahoo.search.query.profile.QueryProfile;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.search.query.profile.types.FieldDescription;
Expand Down Expand Up @@ -417,6 +418,69 @@ void requireThatConfigIsDerivedForQueryFeatureTypeSettings() throws ParseExcepti
assertQueryFeatureTypeSettings(registry.get(schema, "p2"), schema);
}

@Test
void dimensionArgumentResolution() throws ParseException{
RankProfileRegistry registry = new RankProfileRegistry();
ApplicationBuilder builder = new ApplicationBuilder(registry);
builder.addSchema("""
schema test {
document test {
field embeddings type tensor(d1[384]) {
indexing: attribute
}
}
rank-profile feature_logging {
inputs {
query(query_embedding_int8) tensor<int8>(d0[384])
query(query_embedding) tensor<bfloat16>(d0{}, d1[384])
}
first-phase {
expression: fakeRankResult
}
function query_field_cosine_similarity(field_name, query_tensor, dimension) {
expression: cosine_similarity(attribute(field_name), query_tensor, dimension)
}
function query_field_cos_distances(field_name, query_tensor, dimension){
expression: max(1 - query_field_cosine_similarity(field_name, query_tensor, dimension), 0.0)
}
function query_field_acos_distances(field_name, query_tensor, dimension) {
expression: acos(query_field_cosine_similarity(field_name, query_tensor, dimension))
}
function query_field_closeness(field_name, query_tensor, dimension) {
expression: reduce(1/(1+query_field_acos_distances(field_name, query_tensor, dimension)), max)
}
summary-features {
query_field_closeness(embeddings, query(query_embedding), d1)
}
}
}""");
Application application = builder.build(true);
RankProfile profile = application.rankProfileRegistry().get("test", "feature_logging");

// Rank profile content is unbound, as written:
assertEquals("join(reduce(join(attribute(field_name), query_tensor, f(a,b)(a * b)), sum, dimension), " +
"map(join(reduce(join(attribute(field_name), attribute(field_name), f(a,b)(a * b)), sum, dimension), " +
"reduce(join(query_tensor, query_tensor, f(a,b)(a * b)), sum, dimension), " +
"f(a,b)(a * b)), f(a)(sqrt(a))), f(a,b)(a / b))",
profile.findFunction("query_field_cosine_similarity").function().getBody().getRoot().toString());

// Derived rank profile content is bound: attribute(field_name) -> attribute(embeddings), dimension -> d1
assertEquals("join(reduce(join(attribute(embeddings), query(query_embedding), f(a,b)(a * b)), sum, d1), " +
"map(join(reduce(join(attribute(embeddings), attribute(embeddings), f(a,b)(a * b)), sum, d1), " +
"reduce(join(query(query_embedding), query(query_embedding), f(a,b)(a * b)), sum, d1), " +
"f(a,b)(a * b)), f(a)(sqrt(a))), f(a,b)(a / b))",
findDerivedFunction(application, "feature_logging", "query_field_cosine_similarity"));
}

private String findDerivedFunction(Application application, String rankProfileName, String functionName) {
var derived = new DerivedConfiguration(application.schemas().get("test"), application.rankProfileRegistry());
for (var line : derived.getRankProfileList().getRankProfiles().get("feature_logging").configProperties()) {
if (line.getFirst().startsWith("rankingExpression(query_field_cosine_similarity@"))
return line.getSecond();
}
return null;
}

private static QueryProfileRegistry setupQueryProfileTypes() {
QueryProfileRegistry registry = new QueryProfileRegistry();
QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry();
Expand Down
1 change: 1 addition & 0 deletions model-evaluation/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"public com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)",
"public double getDouble(int)",
"public int getIndex(java.lang.String)",
"public java.lang.String resolveBinding(java.lang.String)",
"public int size()",
"public java.util.Set names()",
"public java.util.Set arguments()",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ public int getIndex(String name) {
return requireIndexOf(name);
}

@Override
public String resolveBinding(String argument) {
return null;
}

@Override
public int size() {
return indexedBindings.names().size();
Expand Down
6 changes: 5 additions & 1 deletion searchlib/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@
"public"
],
"methods" : [
"public void <init>(com.yahoo.searchlib.rankingexpression.ExpressionFunction, java.lang.String, java.lang.String)",
"public void <init>(java.lang.String, java.lang.String)",
"public java.lang.String getName()",
"public java.lang.String getExpressionString()"
],
Expand Down Expand Up @@ -424,6 +424,7 @@
"public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)",
"public final com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)",
"public final double getDouble(int)",
"public java.lang.String resolveBinding(java.lang.String)",
"public com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext clone()",
"public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext clone()",
"public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)",
Expand Down Expand Up @@ -537,6 +538,7 @@
"public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)",
"public final com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)",
"public java.lang.String resolveBinding(java.lang.String)",
"public com.yahoo.searchlib.rankingexpression.evaluation.DoubleOnlyArrayContext clone()",
"public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext clone()",
"public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)",
Expand Down Expand Up @@ -629,6 +631,7 @@
"public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)",
"public void put(java.lang.String, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public java.lang.String resolveBinding(java.lang.String)",
"public java.util.Map bindings()",
"public com.yahoo.searchlib.rankingexpression.evaluation.MapContext thawedCopy()",
"public java.util.Set names()",
Expand All @@ -651,6 +654,7 @@
"public void setType(com.yahoo.searchlib.rankingexpression.Reference, com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.TensorType getType(java.lang.String)",
"public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)",
"public java.lang.String resolveBinding(java.lang.String)",
"public java.util.Map bindings()",
"public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)"
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ public String toString() {
* An instance of a serialization of this function, using a particular serialization context (by {@link
* ExpressionFunction#expand})
*/
public class Instance {
public static class Instance {

private final String name;
private final String expressionString;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ public final double getDouble(int index) {
return value;
}

@Override
public String resolveBinding(String argument) {
return null;
}

/**
* Creates a clone of this context suitable for evaluating against the same ranking expression
* in a different thread (i.e, name name to index map, different value set.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ public final Value get(int index) {
return new DoubleValue(getDouble(index));
}

@Override
public String resolveBinding(String argument) {
return null;
}

/**
* Creates a clone of this context suitable for evaluating against the same ranking expression
* in a different thread (i.e, name name to index map, different value set.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ public void put(String key, Value value) {
bindings.put(key, value.freeze());
}

@Override
public String resolveBinding(String argument) {
return null;
}

/** Returns an immutable view of the bindings of this. */
public Map<String, Value> bindings() {
if (frozen) return bindings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ public TensorType getType(Reference reference) {
return featureTypes.get(reference);
}

@Override
public String resolveBinding(String argument) {
return null;
}

/** Returns an unmodifiable map of the bindings in this */
public Map<Reference, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ public StringBuilder toString(StringBuilder string, SerializationContext context
if ( needSerialization ) {
ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path);
functionName = instance.getName();

context.addFunctionSerialization(RankingExpression.propertyName(functionName), instance.getExpressionString());
for (Map.Entry<String, TensorType> argumentType : function.argumentTypes().entrySet())
context.addArgumentTypeSerialization(functionName, argumentType.getKey(), argumentType.getValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ private ExpressionNode toExpressionNode(TensorFunction<Reference> f) {
}

private static ScalarFunction<Reference> transform(ScalarFunction<Reference> input,
Function<ExpressionNode, ExpressionNode> transformer)
{
Function<ExpressionNode, ExpressionNode> transformer) {
if (input instanceof ExpressionScalarFunction wrapper) {
ExpressionNode transformed = transformer.apply(wrapper.expression);
return new ExpressionScalarFunction(transformed);
Expand Down Expand Up @@ -411,6 +410,12 @@ public Value get(String name) {
public TensorType getType(Reference name) {
return delegate.getType(name);
}

@Override
public String resolveBinding(String argument) {
return delegate.resolveBinding(argument);
}

}

private static Context asContext(EvaluationContext<Reference> generic) {
Expand Down
10 changes: 7 additions & 3 deletions vespajlib/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -1570,7 +1570,8 @@
"public void put(java.lang.String, com.yahoo.tensor.Tensor)",
"public com.yahoo.tensor.TensorType getType(java.lang.String)",
"public com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)",
"public com.yahoo.tensor.Tensor getTensor(java.lang.String)"
"public com.yahoo.tensor.Tensor getTensor(java.lang.String)",
"public java.lang.String resolveBinding(java.lang.String)"
],
"fields" : [ ]
},
Expand Down Expand Up @@ -1599,7 +1600,8 @@
],
"methods" : [
"public abstract com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)",
"public abstract com.yahoo.tensor.TensorType getType(java.lang.String)"
"public abstract com.yahoo.tensor.TensorType getType(java.lang.String)",
"public abstract java.lang.String resolveBinding(java.lang.String)"
],
"fields" : [ ]
},
Expand Down Expand Up @@ -1685,7 +1687,7 @@
],
"methods" : [
"public void <init>()",
"public final com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)"
],
"fields" : [ ]
Expand Down Expand Up @@ -1809,6 +1811,7 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public final com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
"public int hashCode()"
],
Expand Down Expand Up @@ -2920,6 +2923,7 @@
"methods" : [
"public static com.yahoo.tensor.functions.ToStringContext empty()",
"public abstract java.lang.String getBinding(java.lang.String)",
"public java.lang.String resolveBinding(java.lang.String)",
"public java.util.Optional typeContext()",
"public abstract com.yahoo.tensor.functions.ToStringContext parent()"
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ public TensorType getType(String name) {
}

@Override
public TensorType getType(NAMETYPE name) {
return getType(name.name());
}
public TensorType getType(NAMETYPE name) { return getType(name.name()); }

@Override
public Tensor getTensor(String name) { return bindings.get(name); }

@Override
public String resolveBinding(String name) { return name; }

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,7 @@ public interface TypeContext<NAMETYPE extends Name> {
*/
TensorType getType(String name);

/** Returns the string a parameter is bound to, or the input name if none. */
String resolveBinding(String name);

}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {

@Override
public String toString(ToStringContext<NAMETYPE> context) {
return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
return "argmax(" + argument.toString(context) + Reduce.commaSeparatedNames(dimensions, context) + ")";
}

@Override
Expand Down
Loading

0 comments on commit b6a2fcb

Please sign in to comment.