Skip to content

Commit

Permalink
Register model with onnx model options
Browse files Browse the repository at this point in the history
  • Loading branch information
Harald Musum committed Nov 20, 2023
1 parent e8c0a04 commit 9d28a47
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 16 deletions.
53 changes: 51 additions & 2 deletions config-model-api/abi-spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -1455,7 +1455,9 @@
"methods" : [
"public abstract long aggregatedModelCostInBytes()",
"public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile)",
"public abstract void registerModel(java.net.URI)"
"public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)",
"public abstract void registerModel(java.net.URI)",
"public abstract void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)"
],
"fields" : [ ]
},
Expand All @@ -1473,7 +1475,9 @@
"public com.yahoo.config.model.api.OnnxModelCost$Calculator newCalculator(com.yahoo.config.application.api.ApplicationPackage, com.yahoo.config.provision.ApplicationId)",
"public long aggregatedModelCostInBytes()",
"public void registerModel(com.yahoo.config.application.api.ApplicationFile)",
"public void registerModel(java.net.URI)"
"public void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)",
"public void registerModel(java.net.URI)",
"public void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)"
],
"fields" : [ ]
},
Expand All @@ -1491,6 +1495,51 @@
],
"fields" : [ ]
},
"com.yahoo.config.model.api.OnnxModelOptions$GpuDevice" : {
"superClass" : "java.lang.Record",
"interfaces" : [ ],
"attributes" : [
"public",
"final",
"record"
],
"methods" : [
"public void <init>(int, boolean)",
"public void <init>(int)",
"public final java.lang.String toString()",
"public final int hashCode()",
"public final boolean equals(java.lang.Object)",
"public int deviceNumber()",
"public boolean required()"
],
"fields" : [ ]
},
"com.yahoo.config.model.api.OnnxModelOptions" : {
"superClass" : "java.lang.Record",
"interfaces" : [ ],
"attributes" : [
"public",
"final",
"record"
],
"methods" : [
"public void <init>(java.lang.String, int, int, com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)",
"public void <init>(java.util.Optional, java.util.Optional, java.util.Optional, java.util.Optional)",
"public static com.yahoo.config.model.api.OnnxModelOptions empty()",
"public com.yahoo.config.model.api.OnnxModelOptions withExecutionMode(java.lang.String)",
"public com.yahoo.config.model.api.OnnxModelOptions withInterOpThreads(java.lang.Integer)",
"public com.yahoo.config.model.api.OnnxModelOptions withIntraOpThreads(java.lang.Integer)",
"public com.yahoo.config.model.api.OnnxModelOptions withGpuDevice(com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)",
"public final java.lang.String toString()",
"public final int hashCode()",
"public final boolean equals(java.lang.Object)",
"public java.util.Optional executionMode()",
"public java.util.Optional interOpThreads()",
"public java.util.Optional intraOpThreads()",
"public java.util.Optional gpuDevice()"
],
"fields" : [ ]
},
"com.yahoo.config.model.api.PortInfo" : {
"superClass" : "java.lang.Object",
"interfaces" : [ ],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
/**
* @author bjorncs
*/
// TODO: Rename
public interface OnnxModelCost {

Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId);

interface Calculator {
long aggregatedModelCostInBytes();
void registerModel(ApplicationFile path);
void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions);
void registerModel(URI uri);
void registerModel(URI uri, OnnxModelOptions onnxModelOptions);
}

static OnnxModelCost disabled() { return new DisabledOnnxModelCost(); }
Expand All @@ -26,7 +29,9 @@ class DisabledOnnxModelCost implements OnnxModelCost, Calculator {
@Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; }
@Override public long aggregatedModelCostInBytes() {return 0;}
@Override public void registerModel(ApplicationFile path) {}
@Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {}
@Override public void registerModel(URI uri) {}
@Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {}
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.component;
package com.yahoo.config.model.api;

import java.util.Optional;

Expand All @@ -12,7 +12,11 @@
public record OnnxModelOptions(Optional<String> executionMode, Optional<Integer> interOpThreads,
Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice) {

public static OnnxModelOptions empty() {
public OnnxModelOptions(String executionMode, int interOpThreads, int intraOpThreads, GpuDevice gpuDevice) {
this(Optional.of(executionMode), Optional.of(interOpThreads), Optional.of(intraOpThreads), Optional.of(gpuDevice));
}

public static OnnxModelOptions empty() {
return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
}

Expand Down
4 changes: 3 additions & 1 deletion config-model/src/main/java/com/yahoo/schema/OnnxModel.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema;

import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.container.component.OnnxModelOptions;
import com.yahoo.vespa.model.ml.OnnxModelInfo;

import java.util.Collections;
Expand Down Expand Up @@ -171,4 +171,6 @@ public Optional<OnnxModelOptions.GpuDevice> getGpuDevice() {
return onnxModelOptions.gpuDevice();
}

public OnnxModelOptions onnxModelOptions() { return onnxModelOptions; }

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package com.yahoo.vespa.model.container.component;

import com.yahoo.config.ModelReference;
import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
Expand Down Expand Up @@ -47,7 +48,7 @@ public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployStat
transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
model.registerOnnxModelCost(cluster);
model.registerOnnxModelCost(cluster, onnxModelOptions);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package com.yahoo.vespa.model.container.component;

import com.yahoo.config.ModelReference;
import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
Expand Down Expand Up @@ -55,7 +56,7 @@ public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployS
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
model.registerOnnxModelCost(cluster);
model.registerOnnxModelCost(cluster, onnxModelOptions);
}

private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package com.yahoo.vespa.model.container.component;

import com.yahoo.config.ModelReference;
import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
Expand Down Expand Up @@ -48,7 +49,7 @@ public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, Dep
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null);
model.registerOnnxModelCost(cluster);
model.registerOnnxModelCost(cluster, onnxModelOptions);
}

private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import com.yahoo.config.ModelReference;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.builder.xml.XmlHelper;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.path.Path;
Expand Down Expand Up @@ -54,10 +55,10 @@ static Model fromXml(DeployState ds, Element model) {
return new Model(ds, model.getTagName(), modelId, url, path);
}

void registerOnnxModelCost(ApplicationContainerCluster c) {
void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) {
var resolvedUrl = resolvedUrl().orElse(null);
if (file != null) c.onnxModelCost().registerModel(file);
else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl);
if (file != null) c.onnxModelCost().registerModel(file, onnxModelOptions);
else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl, onnxModelOptions);
}

String name() { return paramName; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains>
private final List<SearchCluster> searchClusters = new LinkedList<>();
private final Collection<String> schemasWithGlobalPhase;
private final boolean globalPhase;
private final ApplicationPackage app;

private QueryProfiles queryProfiles;
private SemanticRules semanticRules;
private PageTemplates pageTemplates;
private ApplicationPackage app;

public ContainerSearch(DeployState deployState, ApplicationContainerCluster cluster, SearchChains chains) {
super(chains);
Expand Down Expand Up @@ -102,7 +102,7 @@ private void initializeDispatchers(Collection<SearchCluster> searchClusters) {
if ( ! owningCluster.getComponentsMap().containsKey(factory.getComponentId())) {
var onnxModels = documentDb.getDerivedConfiguration().getRankProfileList().getOnnxModels();
onnxModels.asMap().forEach(
(__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath())));
(__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()), model.onnxModelOptions()));
owningCluster.addComponent(factory);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ private void addModelEvaluation(Element spec, ApplicationContainerCluster cluste
!container.getHostResource().realResources().gpuResources().isZero());
onnxModel.setGpuDevice(gpuDevice, hasGpu);
}
cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()));
cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions());
}

cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.vespa.model.application.validation;

import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.NullConfigModelRegistry;
import com.yahoo.config.model.api.ApplicationClusterEndpoint;
import com.yahoo.config.model.api.ContainerEndpoint;
import com.yahoo.config.model.api.OnnxModelCost;
import com.yahoo.config.model.api.OnnxModelOptions;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.config.model.provision.InMemoryProvisioner;
Expand Down Expand Up @@ -123,12 +122,20 @@ private static class ModelCostDummy implements OnnxModelCost, OnnxModelCost.Calc
@Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; }
@Override public long aggregatedModelCostInBytes() { return totalCost.get(); }
@Override public void registerModel(ApplicationFile path) {}
@Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {}

@Override
public void registerModel(URI uri) {
assertEquals("https://my/url/model.onnx", uri.toString());
totalCost.addAndGet(modelCost);
}

@Override
public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {
assertEquals("https://my/url/model.onnx", uri.toString());
totalCost.addAndGet(modelCost);
}

}

}

0 comments on commit 9d28a47

Please sign in to comment.