diff --git a/config-model/src/test/cfg/application/stateless_eval/mul.onnx b/config-model/src/test/cfg/application/stateless_eval/mul.onnx index 087e2c3427f1..26411c96986c 100644 --- a/config-model/src/test/cfg/application/stateless_eval/mul.onnx +++ b/config-model/src/test/cfg/application/stateless_eval/mul.onnx @@ -1,7 +1,10 @@ -mul.py:f - +mul.py:Ÿ + input1 -input2output"MulmulZ +input2output1"Mul + +input1 +input2output2"AddmulZ input1  @@ -9,8 +12,12 @@ input2  -b -output +b +output1 + + +b +output2  B \ No newline at end of file diff --git a/config-model/src/test/cfg/application/stateless_eval/mul.py b/config-model/src/test/cfg/application/stateless_eval/mul.py index 9fcb8612af9a..6bbc4e232007 100755 --- a/config-model/src/test/cfg/application/stateless_eval/mul.py +++ b/config-model/src/test/cfg/application/stateless_eval/mul.py @@ -2,25 +2,31 @@ import onnx from onnx import helper, TensorProto -INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1]) -INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1]) -OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1]) +INPUT1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1]) +INPUT2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1]) +OUTPUT1 = helper.make_tensor_value_info('output1', TensorProto.FLOAT, [1]) +OUTPUT2 = helper.make_tensor_value_info('output2', TensorProto.FLOAT, [1]) nodes = [ helper.make_node( 'Mul', ['input1', 'input2'], - ['output'], + ['output1'], + ), + helper.make_node( + 'Add', + ['input1', 'input2'], + ['output2'], ), ] graph_def = helper.make_graph( nodes, 'mul', [ - INPUT_1, - INPUT_2 + INPUT1, + INPUT2 ], - [OUTPUT], + [OUTPUT1, OUTPUT2], ) model_def = helper.make_model(graph_def, producer_name='mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) onnx.save(model_def, 'mul.onnx') diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java index 5630d3cc186b..8ed229b2ff57 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java @@ -21,12 +21,18 @@ public void testModelsEvaluatorTester() { assertEquals(3, modelsEvaluator.models().size()); // ONNX model evaluation - FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul"); + FunctionEvaluator mul = modelsEvaluator.evaluatorOf("mul", "output1"); Tensor input1 = Tensor.from("tensor(d0[1]):[2]"); Tensor input2 = Tensor.from("tensor(d0[1]):[3]"); Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate(); assertEquals(6.0, output.sum().asDouble(), 1e-9); + FunctionEvaluator eval = modelsEvaluator.evaluatorOf("mul"); + output = eval.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(6.0, output.sum().asDouble(), 1e-9); + assertEquals(6.0, eval.result("output1").sum().asDouble(), 1e-9); + assertEquals(5.0, eval.result("output2").sum().asDouble(), 1e-9); + // LightGBM model evaluation FunctionEvaluator lgbm = modelsEvaluator.evaluatorOf("lightgbm_regression"); lgbm.bind("numerical_1", 0.1).bind("numerical_2", 0.2).bind("categorical_1", "a").bind("categorical_2", "i"); diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index 6728d5cd9b44..71dd7ffc2ebb 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -6,6 +6,7 @@ "public" ], "methods": [ + "public com.yahoo.tensor.Tensor result(java.lang.String)", "public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, com.yahoo.tensor.Tensor)", "public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, double)", "public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, java.lang.String)", @@ -13,7 +14,8 @@ "public ai.vespa.models.evaluation.FunctionEvaluator setMissingValue(double)", "public com.yahoo.tensor.Tensor evaluate()", "public com.yahoo.searchlib.rankingexpression.ExpressionFunction function()", - "public ai.vespa.models.evaluation.LazyArrayContext context()" + "public ai.vespa.models.evaluation.LazyArrayContext context()", + "public java.util.List outputs()" ], "fields": [] }, diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 6af33e29e62c..7a992cb7aa90 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -8,24 +8,37 @@ import com.yahoo.tensor.TensorType; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; /** - * An evaluator which can be used to evaluate a single function once. + * An evaluator which can be used to evaluate a function once. * * @author bratseth */ // This wraps all access to the context and the ranking expression to avoid incorrect usage public class FunctionEvaluator { - private final ExpressionFunction function; - private final LazyArrayContext context; + private final List functions; + private final Map contexts; + private final Map results; private boolean evaluated = false; FunctionEvaluator(ExpressionFunction function, LazyArrayContext context) { - this.function = function; - this.context = context; + this(List.of(function), Map.of(function.getName(), context)); + } + + FunctionEvaluator(List functions, Map contexts) { + this.functions = List.copyOf(functions); + this.contexts = Map.copyOf(contexts); + this.results = new HashMap<>(); + } + + public Tensor result(String name) { + return results.get(name); } /** @@ -38,15 +51,14 @@ public class FunctionEvaluator { public FunctionEvaluator bind(String name, Tensor value) { if (evaluated) throw new IllegalStateException("Cannot bind a new value in a used evaluator"); - TensorType requiredType = function.argumentTypes().get(name); - if (requiredType == null) - throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function + - ". Expected arguments: " + function.argumentTypes().entrySet().stream() - .map(e -> e.getKey() + ": " + e.getValue()) - .collect(Collectors.joining(", "))); - if ( ! value.type().isAssignableTo(requiredType)) - throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type()); - context.put(name, new TensorValue(value)); + for (ExpressionFunction function : functions) { + if (function.argumentTypes().containsKey(name)) { + TensorType requiredType = function.argumentTypes().get(name); + if ( ! value.type().isAssignableTo(requiredType)) + throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type()); + contexts.get(function.getName()).put(name, new TensorValue(value)); + } + } return this; } @@ -73,7 +85,11 @@ public FunctionEvaluator bind(String name, double value) { public FunctionEvaluator bind(String name, String value) { if (evaluated) throw new IllegalStateException("Cannot bind a new value in a used evaluator"); - context.put(name, new StringValue(value)); + for (ExpressionFunction function : functions) { + if (function.argumentTypes().containsKey(name)) { + contexts.get(function.getName()).put(name, new StringValue(value)); + } + } return this; } @@ -86,7 +102,9 @@ public FunctionEvaluator bind(String name, String value) { public FunctionEvaluator setMissingValue(Tensor value) { if (evaluated) throw new IllegalStateException("Cannot change the missing value in a used evaluator"); - context.setMissingValue(value); + for (LazyArrayContext context : contexts.values()) { + context.setMissingValue(value); + } return this; } @@ -101,15 +119,32 @@ public FunctionEvaluator setMissingValue(double value) { } public Tensor evaluate() { - for (Map.Entry argument : function.argumentTypes().entrySet()) { - checkArgument(argument.getKey(), argument.getValue()); + checkArguments(); + evaluateOnnxModels(); + + Tensor defaultResult = null; + for (ExpressionFunction function: functions) { + LazyArrayContext context = contexts.get(function.getName()); + Tensor result = function.getBody().evaluate(context).asTensor(); + results.put(function.getName(), function.getBody().evaluate(context).asTensor()); + if (defaultResult == null) { + defaultResult = result; + } } evaluated = true; - evaluateOnnxModels(); - return function.getBody().evaluate(context).asTensor(); + return defaultResult; } - private void checkArgument(String name, TensorType type) { + void checkArguments() { + for (ExpressionFunction function : functions) { + LazyArrayContext context = contexts.get(function.getName()); + for (Map.Entry argument : function.argumentTypes().entrySet()) { + checkArgument(argument.getKey(), argument.getValue(), context); + } + } + } + + private void checkArgument(String name, TensorType type, LazyArrayContext context) { if (context.isMissing(name)) throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type); if (! context.get(name).type().isAssignableTo(type)) @@ -120,23 +155,52 @@ private void checkArgument(String name, TensorType type) { * Evaluate ONNX models (if not already evaluated) and add the result back to the context. */ private void evaluateOnnxModels() { - for (Map.Entry entry : context().onnxModels().entrySet()) { - String onnxFeature = entry.getKey(); - OnnxModel onnxModel = entry.getValue(); - if (context.get(onnxFeature).equals(context.defaultValue())) { - Map inputs = new HashMap<>(); - for (Map.Entry input: onnxModel.inputs().entrySet()) { - inputs.put(input.getKey(), context.get(input.getKey()).asTensor()); + Set onnxModels = new HashSet<>(); + for (LazyArrayContext context : contexts.values()) { + onnxModels.addAll(context.onnxModels().values()); + } + + for (OnnxModel onnxModel : onnxModels) { + + // Gather inputs from all functions. Inputs with the same name must have the same value. + Map inputs = new HashMap<>(); + for (LazyArrayContext context : contexts.values()) { + for (OnnxModel functionModel : context.onnxModels().values()) { + if (functionModel.name().equals(onnxModel.name())) { + for (String inputName: onnxModel.inputs().keySet()) { + inputs.put(inputName, context.get(inputName).asTensor()); + } + } } - Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model - context.put(onnxFeature, new TensorValue(result)); } + + // Evaluate model once. + Map outputs = onnxModel.evaluate(inputs); + + // Add outputs back to the context of the functions that need them; they won't be recalculated. + for (ExpressionFunction function : functions) { + LazyArrayContext context = contexts.get(function.getName()); + for (Map.Entry entry : context.onnxModels().entrySet()) { + String onnxFeature = entry.getKey(); + OnnxModel functionModel = entry.getValue(); + if (functionModel.name().equals(onnxModel.name())) { + Tensor result = outputs.get(function.getName()); // Function name is output of model + context.put(onnxFeature, new TensorValue(result)); + } + } + } + } } - /** Returns the function evaluated by this */ - public ExpressionFunction function() { return function; } + /** Returns the default function evaluated by this */ + public ExpressionFunction function() { return functions.get(0); } + + public LazyArrayContext context() { return contexts.get(function().getName()); } - public LazyArrayContext context() { return context; } + /** Returns the names of the outputs of this function */ + public List outputs() { + return functions.stream().map(ExpressionFunction::getName).collect(Collectors.toList()); + } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index d97235d11d25..cc53f38f800a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -153,7 +153,7 @@ private static class IndexedBindings { /** The mapping from variable name to index */ private final ImmutableMap nameToIndex; - /** The names which needs to be bound externally when invoking this (i.e not constant or invocation */ + /** The names which needs to be bound externally when invoking this (i.e. not constant or invocation) */ private final ImmutableSet arguments; /** The current values set */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 8af5f7bc4996..ab24986e542b 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -6,14 +6,13 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -183,9 +182,7 @@ ExpressionFunction requireReferencedFunction(FunctionReference reference) { */ public FunctionEvaluator evaluatorOf(String ... names) { // TODO: Parameter overloading? if (names.length == 0) { - if (functions.size() > 1) - throwUndeterminedFunction("More than one function is available in " + this + ", but no name is given"); - return evaluatorOf(functions.get(0)); + return evaluatorOf(functions); } else if (names.length == 1) { String name = names[0]; @@ -232,6 +229,15 @@ private FunctionEvaluator evaluatorOf(ExpressionFunction function) { return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy()); } + /** Returns a single-use evaluator of a function */ + private FunctionEvaluator evaluatorOf(List functions) { + Map contexts = new HashMap<>(); + for (ExpressionFunction function : functions) { + contexts.put(function.getName(), requireContextPrototype(function.getName()).copy()); + } + return new FunctionEvaluator(functions, contexts); + } + private void throwUndeterminedFunction(String message) { throw new IllegalArgumentException(message + ". Available functions: " + functions.stream().map(f -> f.getName()).collect(Collectors.joining(", "))); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index 19a9a1dccd5f..06045b07f7c9 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -50,6 +50,10 @@ public Tensor evaluate(Map inputs, String output) { return evaluator().evaluate(inputs, output); } + public Map evaluate(Map inputs) { + return evaluator().evaluate(inputs); + } + private OnnxEvaluator evaluator() { if (evaluator == null) { throw new IllegalStateException("ONNX model has not been loaded."); diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index b0e2be26f8aa..d5ae1bbf5919 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -4,6 +4,9 @@ import ai.vespa.models.evaluation.FunctionEvaluator; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; @@ -15,6 +18,7 @@ import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.yolean.Exceptions; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; import java.net.URI; @@ -65,14 +69,14 @@ public HttpResponse handle(HttpRequest request) { } return listModelInformation(request, model, function); - } catch (IllegalArgumentException e) { + } catch (IllegalArgumentException | IOException e) { return new ErrorResponse(404, Exceptions.toMessageString(e)); } catch (IllegalStateException e) { // On missing bindings return new ErrorResponse(400, Exceptions.toMessageString(e)); } } - private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) { + private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) throws IOException { FunctionEvaluator evaluator = model.evaluatorOf(function); property(request, missingValueKey).ifPresent(missingValue -> evaluator.setMissingValue(Tensor.from(missingValue))); @@ -87,16 +91,37 @@ private HttpResponse evaluateModel(HttpRequest request, Model model, String[] fu } } } - Tensor result = evaluator.evaluate(); + String format = property(request, "format.tensors").orElse("default"); + if (evaluator.outputs().size() > 1) { + evaluator.evaluate(); + return new Response(200, encodeMultipleResults(evaluator, format)); + } + return new Response(200, encodeSingleResult(evaluator.evaluate(), format)); + } - Optional format = property(request, "format.tensors"); - if (format.isPresent() && format.get().equalsIgnoreCase("short")) { - return new Response(200, JsonFormat.encodeShortForm(result)); + private byte[] encodeMultipleResults(FunctionEvaluator evaluator, String format) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); + g.writeStartObject(); + for (String output : evaluator.outputs()) { + g.writeFieldName(output); + g.writeRawValue(new String(encodeSingleResult(evaluator.result(output), format))); } - else if (format.isPresent() && format.get().equalsIgnoreCase("string")) { - return new Response(200, result.toString().getBytes(StandardCharsets.UTF_8)); + g.writeEndObject(); + g.close(); + return out.toByteArray(); + } + + private byte[] encodeSingleResult(Tensor tensor, String format) { + if (format != null) { + if (format.equalsIgnoreCase("short")) { + return JsonFormat.encodeShortForm(tensor); + } + if (format.equalsIgnoreCase("string")) { + return tensor.toString().getBytes(StandardCharsets.UTF_8); + } } - return new Response(200, JsonFormat.encode(result)); + return JsonFormat.encode(tensor); } private HttpResponse listAllModels(HttpRequest request) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index 4cb52216137a..3e065d25ad20 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -95,8 +95,8 @@ public void testBindingValidation() { evaluator.bind("argNone", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}")); evaluator.evaluate(); } - catch (IllegalArgumentException e) { - assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg2: tensor(d1{}), arg1: tensor(d0[1])", + catch (IllegalStateException e) { + assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})", Exceptions.toMessageString(e)); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java index a15c35fe8548..ae77af264a1d 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java @@ -18,6 +18,7 @@ import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; /** @@ -34,15 +35,22 @@ public void testOnnxEvaluation() { assertTrue(models.models().containsKey("add_mul")); assertTrue(models.models().containsKey("one_layer")); + Tensor input1 = Tensor.from("tensor(d0[1]):[2]"); + Tensor input2 = Tensor.from("tensor(d0[1]):[3]"); + FunctionEvaluator function = models.evaluatorOf("add_mul", "output1"); - function.bind("input1", Tensor.from("tensor(d0[1]):[2]")); - function.bind("input2", Tensor.from("tensor(d0[1]):[3]")); - assertEquals(6.0, function.evaluate().sum().asDouble(), delta); + Tensor result = function.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(6.0, result.sum().asDouble(), delta); function = models.evaluatorOf("add_mul", "output2"); - function.bind("input1", Tensor.from("tensor(d0[1]):[2]")); - function.bind("input2", Tensor.from("tensor(d0[1]):[3]")); - assertEquals(5.0, function.evaluate().sum().asDouble(), delta); + result = function.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(5.0, result.sum().asDouble(), delta); + + function = models.evaluatorOf("add_mul"); // contains two models + result = function.bind("input1", input1).bind("input2", input2).evaluate(); + assertEquals(6.0, result.sum().asDouble(), delta); + assertEquals(6.0, function.result("output1").sum().asDouble(), delta); + assertEquals(5.0, function.result("output2").sum().asDouble(), delta); function = models.evaluatorOf("one_layer"); function.bind("input", Tensor.from("tensor(d0[2],d1[3]):[[0.1, 0.2, 0.3],[0.4,0.5,0.6]]")); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 33e56d5d465e..8c7be4e7be95 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -205,13 +205,6 @@ public void testMnistSavedTypeDetails() { handler.assertResponse(url, 200, expected); } - @Test - public void testMnistSavedEvaluateDefaultFunctionShouldFail() { - String url = "http://localhost/model-evaluation/v1/mnist_saved/eval"; - String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_function_mnist_saved_dnn_hidden1_add, serving_default.y\"}"; - handler.assertResponse(url, 404, expected); - } - @Test public void testVespaModelShortOutput() { Map properties = new HashMap<>(); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java index 6014bd7c7ef4..74715ad96a24 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java @@ -61,8 +61,8 @@ public void testModelInfo() { @Test public void testEvaluationWithoutSpecifyingOutput() { String url = "http://localhost/model-evaluation/v1/add_mul/eval"; - String expected = "{\"error\":\"More than one function is available in model 'add_mul', but no name is given. Available functions: output1, output2\"}"; - handler.assertResponse(url, 404, expected); + String expected = "{\"error\":\"Argument 'input1' must be bound to a value of type tensor(d0[1])\"}"; + handler.assertResponse(url, 400, expected); } @Test @@ -92,6 +92,19 @@ public void testEvaluationOutput2() { handler.assertResponse(url, properties, 200, expected); } + @Test + public void testEvaluateAllOutputs() { + Map properties = new HashMap<>(); + properties.put("input1", "tensor(d0[1]):[2]"); + properties.put("input2", "tensor(d0[1]):[3]"); + String url = "http://localhost/model-evaluation/v1/add_mul/eval"; // remember to add to discovery! + String expected = "{" + + "\"output1\":{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}," + // output1 is a mul + "\"output2\":{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}" + // output1 is an add + "}"; + handler.assertResponse(url, properties, 200, expected); + } + @Test public void testBatchDimensionModelInfo() { String url = "http://localhost/model-evaluation/v1/one_layer";