diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json index d9c68c89189e..78b32d8af7b9 100644 --- a/config-model-api/abi-spec.json +++ b/config-model-api/abi-spec.json @@ -1453,7 +1453,9 @@ "methods" : [ "public abstract long aggregatedModelCostInBytes()", "public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile)", - "public abstract void registerModel(java.net.URI)" + "public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)", + "public abstract void registerModel(java.net.URI)", + "public abstract void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)" ], "fields" : [ ] }, @@ -1471,7 +1473,9 @@ "public com.yahoo.config.model.api.OnnxModelCost$Calculator newCalculator(com.yahoo.config.application.api.ApplicationPackage, com.yahoo.config.provision.ApplicationId)", "public long aggregatedModelCostInBytes()", "public void registerModel(com.yahoo.config.application.api.ApplicationFile)", - "public void registerModel(java.net.URI)" + "public void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)", + "public void registerModel(java.net.URI)", + "public void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)" ], "fields" : [ ] }, @@ -1489,6 +1493,51 @@ ], "fields" : [ ] }, + "com.yahoo.config.model.api.OnnxModelOptions$GpuDevice" : { + "superClass" : "java.lang.Record", + "interfaces" : [ ], + "attributes" : [ + "public", + "final", + "record" + ], + "methods" : [ + "public void (int, boolean)", + "public void (int)", + "public final java.lang.String toString()", + "public final int hashCode()", + "public final boolean equals(java.lang.Object)", + "public int deviceNumber()", + "public boolean required()" + ], + "fields" : [ ] + }, + "com.yahoo.config.model.api.OnnxModelOptions" : { + "superClass" : "java.lang.Record", + "interfaces" : [ ], + "attributes" : [ + "public", + "final", + "record" + ], + "methods" : [ + "public void (java.lang.String, int, int, com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)", + "public void (java.util.Optional, java.util.Optional, java.util.Optional, java.util.Optional)", + "public static com.yahoo.config.model.api.OnnxModelOptions empty()", + "public com.yahoo.config.model.api.OnnxModelOptions withExecutionMode(java.lang.String)", + "public com.yahoo.config.model.api.OnnxModelOptions withInterOpThreads(java.lang.Integer)", + "public com.yahoo.config.model.api.OnnxModelOptions withIntraOpThreads(java.lang.Integer)", + "public com.yahoo.config.model.api.OnnxModelOptions withGpuDevice(com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)", + "public final java.lang.String toString()", + "public final int hashCode()", + "public final boolean equals(java.lang.Object)", + "public java.util.Optional executionMode()", + "public java.util.Optional interOpThreads()", + "public java.util.Optional intraOpThreads()", + "public java.util.Optional gpuDevice()" + ], + "fields" : [ ] + }, "com.yahoo.config.model.api.PortInfo" : { "superClass" : "java.lang.Object", "interfaces" : [ ], diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java index acb880704821..b98667457e42 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java @@ -10,6 +10,7 @@ /** * @author bjorncs */ +// TODO: Rename public interface OnnxModelCost { Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId); @@ -17,7 +18,9 @@ public interface OnnxModelCost { interface Calculator { long aggregatedModelCostInBytes(); void registerModel(ApplicationFile path); + void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions); void registerModel(URI uri); + void registerModel(URI uri, OnnxModelOptions onnxModelOptions); } static OnnxModelCost disabled() { return new DisabledOnnxModelCost(); } @@ -26,7 +29,9 @@ class DisabledOnnxModelCost implements OnnxModelCost, Calculator { @Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; } @Override public long aggregatedModelCostInBytes() {return 0;} @Override public void registerModel(ApplicationFile path) {} + @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {} @Override public void registerModel(URI uri) {} + @Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {} } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java similarity index 84% rename from config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java rename to config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java index 6347f0dc4278..92817baae3fd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java @@ -1,5 +1,5 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.model.container.component; +package com.yahoo.config.model.api; import java.util.Optional; @@ -12,7 +12,11 @@ public record OnnxModelOptions(Optional executionMode, Optional interOpThreads, Optional intraOpThreads, Optional gpuDevice) { - public static OnnxModelOptions empty() { + public OnnxModelOptions(String executionMode, int interOpThreads, int intraOpThreads, GpuDevice gpuDevice) { + this(Optional.of(executionMode), Optional.of(interOpThreads), Optional.of(intraOpThreads), Optional.of(gpuDevice)); + } + + public static OnnxModelOptions empty() { return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); } diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 867ffdb3960c..9456baafd578 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -1,9 +1,9 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; -import com.yahoo.vespa.model.container.component.OnnxModelOptions; import com.yahoo.vespa.model.ml.OnnxModelInfo; import java.util.Collections; @@ -171,4 +171,6 @@ public Optional getGpuDevice() { return onnxModelOptions.gpuDevice(); } + public OnnxModelOptions onnxModelOptions() { return onnxModelOptions; } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index ea3caadc23ad..67fb720b8c09 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -47,7 +48,7 @@ public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployStat transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } @Override diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index cbae50b400c0..d22e6afc3d1a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -55,7 +56,7 @@ public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployS transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); transformerOutput = getChildValue(xml, "transformer-output").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } private static ModelReference resolveDefaultVocab(Model model, DeployState state) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index d1bd0dce0006..d98c72ab3a4c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -48,7 +49,7 @@ public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, Dep transformerOutput = getChildValue(xml, "transformer-output").orElse(null); normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } private static ModelReference resolveDefaultVocab(Model model, DeployState state) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java index c5daf23d6f8d..0d350242fd09 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java @@ -4,6 +4,7 @@ import com.yahoo.config.ModelReference; import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.builder.xml.XmlHelper; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.path.Path; @@ -54,10 +55,10 @@ static Model fromXml(DeployState ds, Element model) { return new Model(ds, model.getTagName(), modelId, url, path); } - void registerOnnxModelCost(ApplicationContainerCluster c) { + void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) { var resolvedUrl = resolvedUrl().orElse(null); - if (file != null) c.onnxModelCost().registerModel(file); - else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl); + if (file != null) c.onnxModelCost().registerModel(file, onnxModelOptions); + else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl, onnxModelOptions); } String name() { return paramName; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java index d86d117f1d2c..31468c05b997 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java @@ -52,11 +52,11 @@ public class ContainerSearch extends ContainerSubsystem private final List searchClusters = new LinkedList<>(); private final Collection schemasWithGlobalPhase; private final boolean globalPhase; + private final ApplicationPackage app; private QueryProfiles queryProfiles; private SemanticRules semanticRules; private PageTemplates pageTemplates; - private ApplicationPackage app; public ContainerSearch(DeployState deployState, ApplicationContainerCluster cluster, SearchChains chains) { super(chains); @@ -102,7 +102,7 @@ private void initializeDispatchers(Collection searchClusters) { if ( ! owningCluster.getComponentsMap().containsKey(factory.getComponentId())) { var onnxModels = documentDb.getDerivedConfiguration().getRankProfileList().getOnnxModels(); onnxModels.asMap().forEach( - (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()))); + (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()), model.onnxModelOptions())); owningCluster.addComponent(factory); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 18020f5df5d6..5ffd34c65574 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -800,7 +800,7 @@ private void addModelEvaluation(Element spec, ApplicationContainerCluster cluste !container.getHostResource().realResources().gpuResources().isZero()); onnxModel.setGpuDevice(gpuDevice, hasGpu); } - cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath())); + cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions()); } cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models)); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java index 8531aff3b1a6..9cadf5cffd80 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java @@ -1,14 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.vespa.model.application.validation; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.ApplicationClusterEndpoint; import com.yahoo.config.model.api.ContainerEndpoint; import com.yahoo.config.model.api.OnnxModelCost; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.provision.InMemoryProvisioner; @@ -123,12 +122,20 @@ private static class ModelCostDummy implements OnnxModelCost, OnnxModelCost.Calc @Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; } @Override public long aggregatedModelCostInBytes() { return totalCost.get(); } @Override public void registerModel(ApplicationFile path) {} + @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {} @Override public void registerModel(URI uri) { assertEquals("https://my/url/model.onnx", uri.toString()); totalCost.addAndGet(modelCost); } + + @Override + public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) { + assertEquals("https://my/url/model.onnx", uri.toString()); + totalCost.addAndGet(modelCost); + } + } }