Skip to content

Commit

Permalink
Merge pull request #24814 from vespa-engine/revert-24813-revert-24773…
Browse files Browse the repository at this point in the history
…-revert-24760-balder/model-importing-code-in-config-model-3

Revert "Revert "Revert "Balder/model importing code in config model [run-systemtest]"""
  • Loading branch information
baldersheim authored Nov 9, 2022
2 parents 606bc6b + ec1413c commit ba3ac72
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 56 deletions.
6 changes: 6 additions & 0 deletions config-model-fat/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
<artifactId>model-integration</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
package com.yahoo.vespa.model;

import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.Version;
import com.yahoo.component.provider.ComponentRegistry;
Expand Down Expand Up @@ -61,6 +56,7 @@ public class VespaModelFactory implements ModelFactory {
/** Creates a factory for Vespa models for this version of the source */
@Inject
public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry,
ComponentRegistry<MlModelImporter> modelImporters,
ComponentRegistry<Validator> additionalValidators,
Zone zone) {
this.version = new Version(VespaVersion.major, VespaVersion.minor, VespaVersion.micro);
Expand All @@ -71,12 +67,7 @@ public VespaModelFactory(ComponentRegistry<ConfigModelPlugin> pluginRegistry,
}
}
this.configModelRegistry = new MapConfigModelRegistry(modelBuilders);
this.modelImporters = List.of(
new VespaImporter(),
new OnnxImporter(),
new TensorFlowImporter(),
new XGBoostImporter(),
new LightGBMImporter());
this.modelImporters = modelImporters.allComponents();
this.zone = zone;
this.additionalValidators = List.copyOf(additionalValidators.allComponents());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ private static Map<String, ExpressionFunction> convertAndStore(ImportedMlModel m
ModelStore store) {
// Add constants
Set<String> constantsReplacedByFunctions = new HashSet<>();
model.smallConstantTensors().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
model.largeConstantTensors().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
constantsReplacedByFunctions, k, v));

// Add functions
Expand Down Expand Up @@ -294,7 +294,8 @@ private static Map<String, ExpressionFunction> convertStored(ModelStore store, R
}

private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName,
Tensor constantValue) {
String constantValueString) {
Tensor constantValue = Tensor.from(constantValueString);
store.writeSmallConstant(constantName, constantValue);
Reference name = FeatureNames.asConstantFeature(constantName);
profile.add(new RankProfile.Constant(name, constantValue));
Expand All @@ -305,7 +306,8 @@ private static void transformLargeConstant(ModelStore store,
QueryProfileRegistry queryProfiles,
Set<String> constantsReplacedByFunctions,
String constantName,
Tensor constantValue) {
String constantValueString) {
Tensor constantValue = Tensor.from(constantValueString);
RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName);
if (rankingExpressionFunctionOverridingConstant != null) {
TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles));
Expand Down
5 changes: 0 additions & 5 deletions fat-model-dependencies/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,6 @@
<artifactId>model-evaluation</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
<artifactId>model-integration</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
<artifactId>metrics</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,38 +80,21 @@ public Optional<String> inputTypeSpec(String input) {
return Optional.ofNullable(inputs.get(input)).map(TensorType::toString);
}

/**
* Returns an immutable map of the small constants of this.
* These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
@Override
public Map<String, Tensor> smallConstantTensors() { return Map.copyOf(smallConstants); }
/**
* Returns an immutable map of the small constants of this, represented as strings on the standard tensor form.
* These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
* @deprecated Use smallConstantTensors instead
*/
@Override
@SuppressWarnings("removal")
@Deprecated(forRemoval = true)
public Map<String, String> smallConstants() { return asStrings(smallConstants); }

boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); }

/**
* Returns an immutable map of the large constants of this.
* These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
* For TensorFlow this corresponds to Variable files stored separately.
*/
@Override
public Map<String, Tensor> largeConstantTensors() { return Map.copyOf(largeConstants); }
/**
* Returns an immutable map of the large constants of this, represented as strings on the standard tensor form.
* These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
* @deprecated Use largeConstantTensors instead
*/
@Override
@SuppressWarnings("removal")
@Deprecated(forRemoval = true)
public Map<String, String> largeConstants() { return asStrings(largeConstants); }

boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.configmodelview;

import com.yahoo.tensor.Tensor;

import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -23,12 +21,8 @@ enum ModelType {
ModelType modelType();

Optional<String> inputTypeSpec(String input);
@Deprecated(forRemoval = true)
Map<String, String> smallConstants();
@Deprecated(forRemoval = true)
Map<String, String> largeConstants();
Map<String, Tensor> smallConstantTensors();
Map<String, Tensor> largeConstantTensors();
Map<String, String> functions();
List<ImportedMlFunction> outputExpressions();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ public void testMnistSoftmaxImport() {
ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx").asNative();

// Check constants
assertEquals(2, model.largeConstantTensors().size());
assertEquals(2, model.largeConstants().size());

Tensor constant0 = model.largeConstantTensors().get("test_Variable");
Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable"));
assertNotNull(constant0);
assertEquals(new TensorType.Builder(TensorType.Value.FLOAT).indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());

Tensor constant1 = model.largeConstantTensors().get("test_Variable_1");
Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1"));
assertNotNull(constant1);
assertEquals(new TensorType.Builder(TensorType.Value.FLOAT).indexed("d1", 10).build(), constant1.type());
assertEquals(10, constant1.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ else if (node instanceof CompositeNode) {

static Context contextFrom(ImportedModel result) {
TestableModelContext context = new TestableModelContext();
result.largeConstantTensors().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstantTensors().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
return context;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
Expand Down Expand Up @@ -38,12 +39,12 @@ private void assertModel(ImportedModel model) {
assertEquals("tensor(name{},x[3])", model.inputs().get("input1").toString());
assertEquals("tensor(x[3])", model.inputs().get("input2").toString());

assertEquals(2, model.smallConstantTensors().size());
assertEquals("tensor(x[3]):[0.5, 1.5, 2.5]", model.smallConstantTensors().get("constant1").toString());
assertEquals("tensor():{3.0}", model.smallConstantTensors().get("constant2").toString());
assertEquals(2, model.smallConstants().size());
assertEquals("tensor(x[3]):[0.5, 1.5, 2.5]", model.smallConstants().get("constant1"));
assertEquals("tensor():{3.0}", model.smallConstants().get("constant2"));

assertEquals(1, model.largeConstantTensors().size());
assertEquals("tensor(x[3]):[0.5, 1.5, 2.5]", model.largeConstantTensors().get("constant1asLarge").toString());
assertEquals(1, model.largeConstants().size());
assertEquals("tensor(x[3]):[0.5, 1.5, 2.5]", model.largeConstants().get("constant1asLarge"));

assertEquals(2, model.expressions().size());
assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1, max, x) * constant2",
Expand Down Expand Up @@ -71,8 +72,8 @@ public void testEmpty() {
assertTrue(model.expressions().isEmpty());
assertTrue(model.functions().isEmpty());
assertTrue(model.inputs().isEmpty());
assertTrue(model.largeConstantTensors().isEmpty());
assertTrue(model.smallConstantTensors().isEmpty());
assertTrue(model.largeConstants().isEmpty());
assertTrue(model.smallConstants().isEmpty());
}

@Test
Expand Down

0 comments on commit ba3ac72

Please sign in to comment.