Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only evaluate ONNX models once in stateless model eval #20154

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions config-model/src/test/cfg/application/stateless_eval/mul.onnx
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
mul.py:f

mul.py:�

input1
input2output"MulmulZ
input2output1"Mul

input1
input2output2"AddmulZ
input1


Z
input2


b
output
b
output1


b
output2


B
20 changes: 13 additions & 7 deletions config-model/src/test/cfg/application/stateless_eval/mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,31 @@
import onnx
from onnx import helper, TensorProto

INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
INPUT1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
INPUT2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
OUTPUT1 = helper.make_tensor_value_info('output1', TensorProto.FLOAT, [1])
OUTPUT2 = helper.make_tensor_value_info('output2', TensorProto.FLOAT, [1])

nodes = [
helper.make_node(
'Mul',
['input1', 'input2'],
['output'],
['output1'],
),
helper.make_node(
'Add',
['input1', 'input2'],
['output2'],
),
]
graph_def = helper.make_graph(
nodes,
'mul',
[
INPUT_1,
INPUT_2
INPUT1,
INPUT2
],
[OUTPUT],
[OUTPUT1, OUTPUT2],
)
model_def = helper.make_model(graph_def, producer_name='mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
onnx.save(model_def, 'mul.onnx')
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,18 @@ public void testModelsEvaluatorTester() {
assertEquals(3, modelsEvaluator.models().size());

// ONNX model evaluation
FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul");
FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul", "output1");
Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]");
Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]");
Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate();
assertEquals(6.0, output.sum().asDouble(), 1e-9);

FunctionEvaluator eval = modelsEvaluator.evaluatorOf("mul");
output = eval.bind("input1", input1).bind("input2", input2).evaluate();
assertEquals(6.0, output.sum().asDouble(), 1e-9);
assertEquals(6.0, eval.result("output1").sum().asDouble(), 1e-9);
assertEquals(5.0, eval.result("output2").sum().asDouble(), 1e-9);

// LightGBM model evaluation
FunctionEvaluator lgbm = modelsEvaluator.evaluatorOf("lightgbm_regression");
lgbm.bind("numerical_1", 0.1).bind("numerical_2", 0.2).bind("categorical_1", "a").bind("categorical_2", "i");
Expand Down
4 changes: 3 additions & 1 deletion model-evaluation/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
"public"
],
"methods": [
"public com.yahoo.tensor.Tensor result(java.lang.String)",
"public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, com.yahoo.tensor.Tensor)",
"public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, double)",
"public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, java.lang.String)",
"public ai.vespa.models.evaluation.FunctionEvaluator setMissingValue(com.yahoo.tensor.Tensor)",
"public ai.vespa.models.evaluation.FunctionEvaluator setMissingValue(double)",
"public com.yahoo.tensor.Tensor evaluate()",
"public com.yahoo.searchlib.rankingexpression.ExpressionFunction function()",
"public ai.vespa.models.evaluation.LazyArrayContext context()"
"public ai.vespa.models.evaluation.LazyArrayContext context()",
"public java.util.List outputs()"
],
"fields": []
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,37 @@
import com.yahoo.tensor.TensorType;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* An evaluator which can be used to evaluate a single function once.
* An evaluator which can be used to evaluate a function once.
*
* @author bratseth
*/
// This wraps all access to the context and the ranking expression to avoid incorrect usage
public class FunctionEvaluator {

private final ExpressionFunction function;
private final LazyArrayContext context;
private final List<ExpressionFunction> functions;
private final Map<String, LazyArrayContext> contexts;
private final Map<String, Tensor> results;
private boolean evaluated = false;

FunctionEvaluator(ExpressionFunction function, LazyArrayContext context) {
this.function = function;
this.context = context;
// Single-function case: delegate to the list-based constructor with a one-element list
// and the context keyed by the name of the function
this(List.of(function), Map.of(function.getName(), context));
}

/**
 * Creates an evaluator over several functions, each evaluated in its own context.
 * Both arguments are copied, so later changes to the given collections do not affect this.
 *
 * @param functions the functions to evaluate; the first one is the default function
 * @param contexts  the context of each function, keyed by function name
 */
FunctionEvaluator(List<ExpressionFunction> functions, Map<String, LazyArrayContext> contexts) {
    this.results = new HashMap<>();
    this.contexts = Map.copyOf(contexts);
    this.functions = List.copyOf(functions);
}

/**
 * Returns the result stored for the function with the given name.
 *
 * @param name the name of a function evaluated by this
 * @return the tensor stored under that name, or null if nothing is stored for it
 */
public Tensor result(String name) {
    Tensor stored = results.get(name);
    return stored;
}

/**
Expand All @@ -38,15 +51,14 @@ public class FunctionEvaluator {
public FunctionEvaluator bind(String name, Tensor value) {
    if (evaluated)
        throw new IllegalStateException("Cannot bind a new value in a used evaluator");

    // Bind the value in the context of every function which declares this argument
    boolean bound = false;
    for (ExpressionFunction function : functions) {
        TensorType requiredType = function.argumentTypes().get(name);
        if (requiredType == null) continue; // not an argument of this function
        if ( ! value.type().isAssignableTo(requiredType))
            throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type());
        contexts.get(function.getName()).put(name, new TensorValue(value));
        bound = true;
    }

    // A name which no function accepts is a caller error: reject it instead of silently dropping the value.
    // With a single function the message is the same as before this evaluator supported several functions.
    if ( ! bound)
        throw new IllegalArgumentException("'" + name + "' is not a valid argument in " +
                                           (functions.size() == 1 ? functions.get(0) : functions) +
                                           ". Expected arguments: " +
                                           functions.stream()
                                                    .flatMap(f -> f.argumentTypes().entrySet().stream())
                                                    .map(e -> e.getKey() + ": " + e.getValue())
                                                    .distinct()
                                                    .collect(Collectors.joining(", ")));
    return this;
}

Expand All @@ -73,7 +85,11 @@ public FunctionEvaluator bind(String name, double value) {
public FunctionEvaluator bind(String name, String value) {
if (evaluated)
throw new IllegalStateException("Cannot bind a new value in a used evaluator");
context.put(name, new StringValue(value));
// Bind the value in the context of every function which declares this argument.
// NOTE(review): a name which no function declares is silently ignored by this loop - confirm that is intended.
for (ExpressionFunction function : functions) {
if (function.argumentTypes().containsKey(name)) {
contexts.get(function.getName()).put(name, new StringValue(value));
}
}
return this;
}

Expand All @@ -86,7 +102,9 @@ public FunctionEvaluator bind(String name, String value) {
public FunctionEvaluator setMissingValue(Tensor value) {
if (evaluated)
throw new IllegalStateException("Cannot change the missing value in a used evaluator");
context.setMissingValue(value);
// Apply the same missing value to the context of every function
for (LazyArrayContext context : contexts.values()) {
context.setMissingValue(value);
}
return this;
}

Expand All @@ -101,15 +119,32 @@ public FunctionEvaluator setMissingValue(double value) {
}

public Tensor evaluate() {
for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
checkArgument(argument.getKey(), argument.getValue());
checkArguments();
evaluateOnnxModels();

Tensor defaultResult = null;
for (ExpressionFunction function: functions) {
LazyArrayContext context = contexts.get(function.getName());
Tensor result = function.getBody().evaluate(context).asTensor();
results.put(function.getName(), function.getBody().evaluate(context).asTensor());
if (defaultResult == null) {
defaultResult = result;
}
}
evaluated = true;
evaluateOnnxModels();
return function.getBody().evaluate(context).asTensor();
return defaultResult;
}

private void checkArgument(String name, TensorType type) {
/** Checks every declared argument of every function against the context of that function */
void checkArguments() {
    for (ExpressionFunction function : functions) {
        LazyArrayContext functionContext = contexts.get(function.getName());
        function.argumentTypes().forEach((name, type) -> checkArgument(name, type, functionContext));
    }
}

private void checkArgument(String name, TensorType type, LazyArrayContext context) {
if (context.isMissing(name))
throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type);
if (! context.get(name).type().isAssignableTo(type))
Expand All @@ -120,23 +155,52 @@ private void checkArgument(String name, TensorType type) {
* Evaluate ONNX models (if not already evaluated) and add the result back to the context.
*/
private void evaluateOnnxModels() {
for (Map.Entry<String, OnnxModel> entry : context().onnxModels().entrySet()) {
String onnxFeature = entry.getKey();
OnnxModel onnxModel = entry.getValue();
if (context.get(onnxFeature).equals(context.defaultValue())) {
Map<String, Tensor> inputs = new HashMap<>();
for (Map.Entry<String, TensorType> input: onnxModel.inputs().entrySet()) {
inputs.put(input.getKey(), context.get(input.getKey()).asTensor());
// Collect the distinct models referenced from the context of any function.
// NOTE(review): de-duplication in this set relies on OnnxModel equality, while the loops
// below match models by name() - confirm that the two notions agree.
Set<OnnxModel> onnxModels = new HashSet<>();
for (LazyArrayContext context : contexts.values()) {
onnxModels.addAll(context.onnxModels().values());
}

for (OnnxModel onnxModel : onnxModels) {

// Gather inputs from all functions. Inputs with the same name must have the same value.
Map<String, Tensor> inputs = new HashMap<>();
for (LazyArrayContext context : contexts.values()) {
for (OnnxModel functionModel : context.onnxModels().values()) {
if (functionModel.name().equals(onnxModel.name())) {
for (String inputName: onnxModel.inputs().keySet()) {
inputs.put(inputName, context.get(inputName).asTensor());
}
}
}
Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model
context.put(onnxFeature, new TensorValue(result));
}

// Evaluate model once.
Map<String, Tensor> outputs = onnxModel.evaluate(inputs);

// Add outputs back to the context of the functions that need them; they won't be recalculated.
for (ExpressionFunction function : functions) {
LazyArrayContext context = contexts.get(function.getName());
for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) {
String onnxFeature = entry.getKey();
OnnxModel functionModel = entry.getValue();
if (functionModel.name().equals(onnxModel.name())) {
// NOTE(review): presumably null here if the model has no output named after the function -
// confirm that cannot happen, or that TensorValue accepts null.
Tensor result = outputs.get(function.getName()); // Function name is output of model
context.put(onnxFeature, new TensorValue(result));
}
}
}

}
}

/** Returns the function evaluated by this */
public ExpressionFunction function() { return function; }
/** Returns the default function evaluated by this */
public ExpressionFunction function() { return functions.get(0); }

/** Returns the context registered under the name of the default function */
public LazyArrayContext context() { return contexts.get(function().getName()); }

public LazyArrayContext context() { return context; }
/** Returns the output names of this evaluator: the name of each of its functions, in order */
public List<String> outputs() {
    return functions.stream()
                    .map(function -> function.getName())
                    .collect(Collectors.toList());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ private static class IndexedBindings {
/** The mapping from variable name to index */
private final ImmutableMap<String, Integer> nameToIndex;

/** The names which needs to be bound externally when invoking this (i.e not constant or invocation */
/** The names which needs to be bound externally when invoking this (i.e. not constant or invocation) */
private final ImmutableSet<String> arguments;

/** The current values set */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -183,9 +182,7 @@ ExpressionFunction requireReferencedFunction(FunctionReference reference) {
*/
public FunctionEvaluator evaluatorOf(String ... names) { // TODO: Parameter overloading?
if (names.length == 0) {
if (functions.size() > 1)
throwUndeterminedFunction("More than one function is available in " + this + ", but no name is given");
return evaluatorOf(functions.get(0));
return evaluatorOf(functions);
}
else if (names.length == 1) {
String name = names[0];
Expand Down Expand Up @@ -232,6 +229,15 @@ private FunctionEvaluator evaluatorOf(ExpressionFunction function) {
return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy());
}

/** Returns a single-use evaluator of the given functions, each with its own copy of its context prototype */
private FunctionEvaluator evaluatorOf(List<ExpressionFunction> functions) {
    Map<String, LazyArrayContext> contextsByName = new HashMap<>();
    functions.forEach(f -> contextsByName.put(f.getName(), requireContextPrototype(f.getName()).copy()));
    return new FunctionEvaluator(functions, contextsByName);
}

private void throwUndeterminedFunction(String message) {
throw new IllegalArgumentException(message + ". Available functions: " +
functions.stream().map(f -> f.getName()).collect(Collectors.joining(", ")));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ public Tensor evaluate(Map<String, Tensor> inputs, String output) {
return evaluator().evaluate(inputs, output);
}

/**
 * Evaluates this model on the given inputs.
 *
 * @param inputs the input tensors by name
 * @return what the underlying evaluator returns for these inputs;
 *         presumably all model outputs by name (confirm against OnnxEvaluator)
 * @throws IllegalStateException if the model has not been loaded
 */
public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
    OnnxEvaluator loaded = evaluator();
    return loaded.evaluate(inputs);
}

private OnnxEvaluator evaluator() {
if (evaluator == null) {
throw new IllegalStateException("ONNX model has not been loaded.");
Expand Down
Loading