diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 4c5a727a18..050f0597ae 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -2615,5 +2615,18 @@ def overwrite_shape_and_generate_input( class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig - DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14. + + +class RTDetrOnnxConfig(ViTOnnxConfig): + # Export the operator 'aten::grid_sampler' to ONNX fails under opset 16. + # Support for this operator was added in version 16. + DEFAULT_ONNX_OPSET = 16 + ATOL_FOR_VALIDATION = 1e-5 + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + return { + "pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, + "pixel_mask": {0: "batch_size"}, + } diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 7cb5a31d2d..4ffb63fd6a 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1074,6 +1074,10 @@ class TasksManager: onnx="RoFormerOnnxConfig", tflite="RoFormerTFLiteConfig", ), + "rt-detr": supported_tasks_mapping( + "object-detection", + onnx="RTDetrOnnxConfig", + ), "sam": supported_tasks_mapping( "feature-extraction", onnx="SamOnnxConfig", diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 900b5f3b5c..6fdffd132f 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -152,6 +152,7 @@ "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-RobertaModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel", + "rt-detr": "PekingU/rtdetr_r18vd", "sam": "fxmarty/sam-vit-tiny-random", "segformer": "hf-internal-testing/tiny-random-SegformerModel", "siglip": "hf-internal-testing/tiny-random-SiglipModel", @@ -280,6 +281,7 @@ "resnet": "microsoft/resnet-50", "roberta": "roberta-base", "roformer": "junnyu/roformer_chinese_base", + "rt-detr": "PekingU/rtdetr_r101vd", "sam": "facebook/sam-vit-base", "segformer": "nvidia/segformer-b0-finetuned-ade-512-512", "siglip": "google/siglip-base-patch16-224",