Skip to content

Commit

Permalink
Use OpenCLIPOnnxConfig and add to slow tests
Browse files Browse the repository at this point in the history
  • Loading branch information
isaac-chung committed Oct 18, 2023
1 parent 71fe301 commit 57dc032
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 2 deletions.
21 changes: 21 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,27 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
}



class OpenCLIPOnnxConfig(CLIPOnnxConfig):
    """ONNX export configuration for OpenCLIP checkpoints.

    Reuses the CLIP dummy-input machinery from ``CLIPOnnxConfig`` but renames
    the export inputs to the names the OpenCLIP model expects.
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Map each CLIP-style output name to its dynamic-axis specification."""
        dynamic_axes = dict(
            logits_per_image={0: "image_batch_size", 1: "text_batch_size"},
            logits_per_text={0: "text_batch_size", 1: "image_batch_size"},
            text_embeds={0: "text_batch_size"},
            image_embeds={0: "image_batch_size"},
        )
        return dynamic_axes

    def rename_ambiguous_inputs(self, inputs):
        # Rename the Transformers-style export inputs (`pixel_values`,
        # `input_ids`) to the argument names used by the OpenCLIP model
        # signature (`image`, `text`), so the exported graph inputs match.
        return {
            "image": inputs["pixel_values"],
            "text": inputs["input_ids"],
        }


class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
# The ONNX export of this architecture needs the Trilu operator support, available since opset 14
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ class TasksManager:
),
"open-clip": supported_tasks_mapping(
"zero-shot-image-classification",
onnx="CLIPOnnxConfig",
onnx="OpenCLIPOnnxConfig",
),
"clip": supported_tasks_mapping(
"feature-extraction",
Expand Down
4 changes: 4 additions & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,7 @@
"resnext101-32x8d": "timm/resnext101_32x8d.tv_in1k",
"resnext101-64x4d": "timm/resnext101_64x4d.c1_in1k",
}

PYTORCH_OPEN_CLIP_MODEL = {
"open-clip": "laion/CLIP-ViT-B-16-laion2B-s34B-b88K",
}
24 changes: 23 additions & 1 deletion tests/exporters/onnx/test_exporters_onnx_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
if is_torch_available():
from optimum.exporters.tasks import TasksManager

from ..exporters_utils import PYTORCH_EXPORT_MODELS_TINY, PYTORCH_STABLE_DIFFUSION_MODEL, PYTORCH_TIMM_MODEL
from ..exporters_utils import PYTORCH_EXPORT_MODELS_TINY, PYTORCH_STABLE_DIFFUSION_MODEL, PYTORCH_TIMM_MODEL, PYTORCH_OPEN_CLIP_MODEL


def _get_models_to_test(export_models_dict: Dict):
Expand Down Expand Up @@ -257,6 +257,28 @@ def test_exporters_cli_fp16_timm(
):
self._onnx_export(model_name, task, monolith, no_post_process, device="cuda", fp16=True)

@parameterized.expand(PYTORCH_OPEN_CLIP_MODEL.items())
@require_torch
@require_vision
def test_exporters_cli_pytorch_cpu_open_clip(self, model_type: str, model_name: str):
    """Run the ONNX export CLI for each OpenCLIP checkpoint on the default (CPU) device."""
    # No `device=` argument: exercises the exporter's default device path.
    self._onnx_export(model_name, model_type)

@parameterized.expand(PYTORCH_OPEN_CLIP_MODEL.items())
@require_torch_gpu
@require_vision
@slow
@pytest.mark.run_slow
def test_exporters_cli_pytorch_gpu_open_clip(self, model_type: str, model_name: str):
    """Run the ONNX export CLI for each OpenCLIP checkpoint on CUDA (slow suite only)."""
    self._onnx_export(model_name, model_type, device="cuda")

@parameterized.expand(PYTORCH_OPEN_CLIP_MODEL.items())
@require_torch_gpu
@require_vision
@slow
@pytest.mark.run_slow
def test_exporters_cli_fp16_open_clip(self, model_type: str, model_name: str):
    """Run the ONNX export CLI for each OpenCLIP checkpoint on CUDA in float16 (slow suite only)."""
    self._onnx_export(model_name, model_type, device="cuda", fp16=True)

@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY))
@require_torch
@require_vision
Expand Down
52 changes: 52 additions & 0 deletions tests/exporters/onnx/test_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
PYTORCH_EXPORT_MODELS_TINY,
PYTORCH_STABLE_DIFFUSION_MODEL,
PYTORCH_TIMM_MODEL,
PYTORCH_OPEN_CLIP_MODEL,
TENSORFLOW_EXPORT_MODELS,
VALIDATE_EXPORT_ON_SHAPES_SLOW,
)
Expand Down Expand Up @@ -446,6 +447,57 @@ def test_pytorch_export_for_timm_on_cuda(
monolith=monolith,
)

@parameterized.expand(_get_models_to_test(PYTORCH_OPEN_CLIP_MODEL))
@require_torch
@require_vision
@pytest.mark.run_slow
@slow
def test_pytorch_export_for_open_clip_on_cpu(
    self,
    test_name,
    name,
    model_name,
    task,
    onnx_config_class_constructor,
    monolith: bool,
):
    """Export each OpenCLIP model/task combination to ONNX on CPU (slow suite).

    Validation runs over the extended shape grid (VALIDATE_EXPORT_ON_SHAPES_SLOW),
    which is why this test is gated behind the slow markers.
    """
    self._onnx_export(
        test_name,
        name,
        model_name,
        task,
        onnx_config_class_constructor,
        shapes_to_validate=VALIDATE_EXPORT_ON_SHAPES_SLOW,
        monolith=monolith,
    )

@parameterized.expand(_get_models_to_test(PYTORCH_OPEN_CLIP_MODEL))
@require_torch
@require_vision
@require_torch_gpu
@slow
@pytest.mark.run_slow
@pytest.mark.gpu_test
def test_pytorch_export_for_open_clip_on_cuda(
    self,
    test_name,
    name,
    model_name,
    task,
    onnx_config_class_constructor,
    monolith: bool,
):
    """Export each OpenCLIP model/task combination to ONNX on CUDA (slow + GPU suite).

    Mirrors the CPU variant but sets device="cuda"; validation runs over the
    extended shape grid (VALIDATE_EXPORT_ON_SHAPES_SLOW).
    """
    self._onnx_export(
        test_name,
        name,
        model_name,
        task,
        onnx_config_class_constructor,
        device="cuda",
        shapes_to_validate=VALIDATE_EXPORT_ON_SHAPES_SLOW,
        monolith=monolith,
    )


class CustomWhisperOnnxConfig(WhisperOnnxConfig):
@property
Expand Down

0 comments on commit 57dc032

Please sign in to comment.