diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 10e298dba77..77a3a0d0fb0 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -151,6 +151,7 @@ class OnnxConfig(ExportConfig, ABC): "feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}), "fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "image-classification": OrderedDict({"logits": {0: "batch_size"}}), + "image-feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}), "image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}), "image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}), "image-to-image": OrderedDict( diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 8c87da6cc6f..2f5794eecdd 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -168,6 +168,7 @@ class TasksManager: "feature-extraction": "AutoModel", "fill-mask": "AutoModelForMaskedLM", "image-classification": "AutoModelForImageClassification", + "image-feature-extraction": "AutoModel", "image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"), "image-to-image": "AutoModelForImageToImage", "image-to-text": "AutoModelForVision2Seq", @@ -462,11 +463,13 @@ class TasksManager: ), "convnext": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "image-classification", onnx="ConvNextOnnxConfig", ), "convnextv2": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "image-classification", onnx="ConvNextV2OnnxConfig", ), @@ -483,6 +486,7 @@ class TasksManager: "data2vec-vision": supported_tasks_mapping( "feature-extraction", "image-classification", + "image-feature-extraction", # ONNX doesn't support `adaptive_avg_pool2d` yet # "semantic-segmentation", onnx="Data2VecVisionOnnxConfig", @@ -515,11 +519,16 @@ class TasksManager: tflite="DebertaV2TFLiteConfig", ), "deit": supported_tasks_mapping( - "feature-extraction", "image-classification", "masked-im", onnx="DeiTOnnxConfig" + "feature-extraction", + "image-feature-extraction", + "image-classification", + "masked-im", + onnx="DeiTOnnxConfig", ), "detr": supported_tasks_mapping( "feature-extraction", "object-detection", + "image-feature-extraction", "image-segmentation", onnx="DetrOnnxConfig", ), @@ -546,6 +555,7 @@ class TasksManager: ), "dpt": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "depth-estimation", onnx="DptOnnxConfig", ), @@ -602,6 +612,7 @@ class TasksManager: ), "glpn": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "depth-estimation", onnx="GlpnOnnxConfig", ), @@ -669,6 +680,7 @@ class TasksManager: ), "imagegpt": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "image-classification", onnx="ImageGPTOnnxConfig", ), @@ -700,7 +712,9 @@ class TasksManager: "token-classification", onnx="LiltOnnxConfig", ), - "levit": supported_tasks_mapping("feature-extraction", "image-classification", onnx="LevitOnnxConfig"), + "levit": supported_tasks_mapping( + "feature-extraction", "image-classification", "image-feature-extraction", onnx="LevitOnnxConfig" + ), "longt5": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", @@ -764,17 +778,19 @@ class TasksManager: "mobilevit": supported_tasks_mapping( "feature-extraction", "image-classification", + "image-feature-extraction", "image-segmentation", onnx="MobileViTOnnxConfig", ), "mobilenet-v1": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "image-classification", onnx="MobileNetV1OnnxConfig", ), "mobilenet-v2": supported_tasks_mapping( - "feature-extraction", "image-classification", + "image-feature-extraction", onnx="MobileNetV2OnnxConfig", ), "mpnet": supported_tasks_mapping( @@ -870,16 +886,22 @@ class TasksManager: ), "poolformer": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "image-classification", onnx="PoolFormerOnnxConfig", ), "regnet": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "image-classification", onnx="RegNetOnnxConfig", ), "resnet": supported_tasks_mapping( - "feature-extraction", "image-classification", onnx="ResNetOnnxConfig", tflite="ResNetTFLiteConfig" + "feature-extraction", + "image-classification", + "image-feature-extraction", + onnx="ResNetOnnxConfig", + tflite="ResNetTFLiteConfig", ), "roberta": supported_tasks_mapping( "feature-extraction", @@ -913,6 +935,7 @@ class TasksManager: "segformer": supported_tasks_mapping( "feature-extraction", "image-classification", + "image-feature-extraction", "image-segmentation", "semantic-segmentation", onnx="SegformerOnnxConfig", @@ -957,12 +980,14 @@ class TasksManager: ), "swin": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "image-classification", "masked-im", onnx="SwinOnnxConfig", ), "swin2sr": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "image-to-image", onnx="Swin2srOnnxConfig", ), @@ -975,6 +1000,7 @@ class TasksManager: ), "table-transformer": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "object-detection", onnx="TableTransformerOnnxConfig", ), @@ -1007,7 +1033,7 @@ class TasksManager: onnx="VisionEncoderDecoderOnnxConfig", ), "vit": supported_tasks_mapping( - "feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig" + "feature-extraction", "image-classification", "image-feature-extraction", "masked-im", onnx="ViTOnnxConfig" ), "wavlm": supported_tasks_mapping( "feature-extraction", @@ -1066,6 +1092,7 @@ class TasksManager: ), "yolos": supported_tasks_mapping( "feature-extraction", + "image-feature-extraction", "object-detection", onnx="YolosOnnxConfig", ),