Skip to content

Commit

Permalink
Add image feature extraction task (#1726)
Browse files: browse the repository at this point in the history
* add image feature extraction

* fix test
  • Loading branch information
fxmarty authored Feb 28, 2024
1 parent c7cc312 commit dfca3fd
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class OnnxConfig(ExportConfig, ABC):
"feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}),
"fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-classification": OrderedDict({"logits": {0: "batch_size"}}),
"image-feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}),
"image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
"image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-to-image": OrderedDict(
Expand Down
37 changes: 32 additions & 5 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class TasksManager:
"feature-extraction": "AutoModel",
"fill-mask": "AutoModelForMaskedLM",
"image-classification": "AutoModelForImageClassification",
"image-feature-extraction": "AutoModel",
"image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
"image-to-image": "AutoModelForImageToImage",
"image-to-text": "AutoModelForVision2Seq",
Expand Down Expand Up @@ -462,11 +463,13 @@ class TasksManager:
),
"convnext": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="ConvNextOnnxConfig",
),
"convnextv2": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="ConvNextV2OnnxConfig",
),
Expand All @@ -483,6 +486,7 @@ class TasksManager:
"data2vec-vision": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-feature-extraction",
# ONNX doesn't support `adaptive_avg_pool2d` yet
# "semantic-segmentation",
onnx="Data2VecVisionOnnxConfig",
Expand Down Expand Up @@ -515,11 +519,16 @@ class TasksManager:
tflite="DebertaV2TFLiteConfig",
),
"deit": supported_tasks_mapping(
"feature-extraction", "image-classification", "masked-im", onnx="DeiTOnnxConfig"
"feature-extraction",
"image-feature-extraction",
"image-classification",
"masked-im",
onnx="DeiTOnnxConfig",
),
"detr": supported_tasks_mapping(
"feature-extraction",
"object-detection",
"image-feature-extraction",
"image-segmentation",
onnx="DetrOnnxConfig",
),
Expand All @@ -546,6 +555,7 @@ class TasksManager:
),
"dpt": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"depth-estimation",
onnx="DptOnnxConfig",
),
Expand Down Expand Up @@ -602,6 +612,7 @@ class TasksManager:
),
"glpn": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"depth-estimation",
onnx="GlpnOnnxConfig",
),
Expand Down Expand Up @@ -669,6 +680,7 @@ class TasksManager:
),
"imagegpt": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="ImageGPTOnnxConfig",
),
Expand Down Expand Up @@ -700,7 +712,9 @@ class TasksManager:
"token-classification",
onnx="LiltOnnxConfig",
),
"levit": supported_tasks_mapping("feature-extraction", "image-classification", onnx="LevitOnnxConfig"),
"levit": supported_tasks_mapping(
"feature-extraction", "image-classification", "image-feature-extraction", onnx="LevitOnnxConfig"
),
"longt5": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down Expand Up @@ -764,17 +778,19 @@ class TasksManager:
"mobilevit": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-feature-extraction",
"image-segmentation",
onnx="MobileViTOnnxConfig",
),
"mobilenet-v1": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="MobileNetV1OnnxConfig",
),
"mobilenet-v2": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-feature-extraction",
onnx="MobileNetV2OnnxConfig",
),
"mpnet": supported_tasks_mapping(
Expand Down Expand Up @@ -870,16 +886,22 @@ class TasksManager:
),
"poolformer": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="PoolFormerOnnxConfig",
),
"regnet": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
onnx="RegNetOnnxConfig",
),
"resnet": supported_tasks_mapping(
"feature-extraction", "image-classification", onnx="ResNetOnnxConfig", tflite="ResNetTFLiteConfig"
"feature-extraction",
"image-classification",
"image-feature-extraction",
onnx="ResNetOnnxConfig",
tflite="ResNetTFLiteConfig",
),
"roberta": supported_tasks_mapping(
"feature-extraction",
Expand Down Expand Up @@ -913,6 +935,7 @@ class TasksManager:
"segformer": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-feature-extraction",
"image-segmentation",
"semantic-segmentation",
onnx="SegformerOnnxConfig",
Expand Down Expand Up @@ -957,12 +980,14 @@ class TasksManager:
),
"swin": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-classification",
"masked-im",
onnx="SwinOnnxConfig",
),
"swin2sr": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"image-to-image",
onnx="Swin2srOnnxConfig",
),
Expand All @@ -975,6 +1000,7 @@ class TasksManager:
),
"table-transformer": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"object-detection",
onnx="TableTransformerOnnxConfig",
),
Expand Down Expand Up @@ -1007,7 +1033,7 @@ class TasksManager:
onnx="VisionEncoderDecoderOnnxConfig",
),
"vit": supported_tasks_mapping(
"feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig"
"feature-extraction", "image-classification", "image-feature-extraction", "masked-im", onnx="ViTOnnxConfig"
),
"wavlm": supported_tasks_mapping(
"feature-extraction",
Expand Down Expand Up @@ -1066,6 +1092,7 @@ class TasksManager:
),
"yolos": supported_tasks_mapping(
"feature-extraction",
"image-feature-extraction",
"object-detection",
onnx="YolosOnnxConfig",
),
Expand Down

0 comments on commit dfca3fd

Please sign in to comment.