From 64f795d5b8b4f69eb761586392805054188573ed Mon Sep 17 00:00:00 2001 From: Nathan Azrak <42650258+nathan-az@users.noreply.github.com> Date: Fri, 1 Mar 2024 02:06:48 +1100 Subject: [PATCH] Adds MPNet to NormalizedConfig and ORTConfigManager (#1471) Adds mpnet to ORTConfigManager and NormalizedConfigManager Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> --- optimum/onnxruntime/utils.py | 1 + optimum/utils/normalized_config.py | 1 + tests/onnxruntime/test_modeling.py | 12 +++++++++++- tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index e269c8de718..6da38c7ea7a 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -122,6 +122,7 @@ class ORTConfigManager: "marian": "bart", "mbart": "bart", "mistral": "gpt2", + "mpnet": "bert", "mt5": "bart", "m2m-100": "bart", "nystromformer": "bert", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index f77978985d7..2f705aed2c4 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -241,6 +241,7 @@ class NormalizedConfigManager: "mbart": BartLikeNormalizedTextConfig, "mistral": NormalizedTextConfigWithGQA, "mixtral": NormalizedTextConfigWithGQA, + "mpnet": NormalizedTextConfig, "mpt": MPTNormalizedTextConfig, "mt5": T5LikeNormalizedTextConfig, "m2m-100": BartLikeNormalizedTextConfig, diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 6615a634df3..aefa12b8c4a 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1435,6 +1435,7 @@ class ORTModelForMaskedLMIntegrationTest(ORTModelTestMixin): "flaubert", "ibert", "mobilebert", + "mpnet", "perceiver_text", "roberta", "roformer", @@ -1976,7 +1977,16 @@ def test_compare_to_io_binding(self, model_arch): class ORTModelForFeatureExtractionIntegrationTest(ORTModelTestMixin): - SUPPORTED_ARCHITECTURES = ["albert", "bert", "camembert", "distilbert", "electra", "roberta", "xlm_roberta"] + SUPPORTED_ARCHITECTURES = [ + "albert", + "bert", + "camembert", + "distilbert", + "electra", + "mpnet", + "roberta", + "xlm_roberta", + ] FULL_GRID = {"model_arch": SUPPORTED_ARCHITECTURES} ORTMODEL_CLASS = ORTModelForFeatureExtraction diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index d444dde6ae4..965ce2c27d8 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -120,6 +120,7 @@ "mobilenet_v1": "google/mobilenet_v1_0.75_192", "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model", "mobilevit": "hf-internal-testing/tiny-random-mobilevit", + "mpnet": "hf-internal-testing/tiny-random-MPNetModel", "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", "mt5": "lewtun/tiny-random-mt5", "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",