Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move check_dummy_inputs_allowed to common export utils #2114

Merged
merged 9 commits into from
Dec 19, 2024
27 changes: 2 additions & 25 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from inspect import signature
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import onnx
Expand All @@ -45,6 +45,7 @@
from ...utils.save_utils import maybe_save_preprocessors
from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError
from ..tasks import TasksManager
from ..utils import check_dummy_inputs_are_allowed
from .base import OnnxConfig
from .constants import UNPICKABLE_ARCHS
from .model_configs import SpeechT5OnnxConfig
Expand Down Expand Up @@ -75,30 +76,6 @@ class DynamicAxisNameError(ValueError):
pass


def check_dummy_inputs_are_allowed(
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str]
):
"""
Checks that the dummy inputs from the ONNX config is a subset of the allowed inputs for `model`.
Args:
model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel`]):
The model instance.
model_inputs (`Iterable[str]`):
The model input names.
"""

forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call
forward_parameters = signature(forward).parameters
forward_inputs_set = set(forward_parameters.keys())
dummy_input_names = set(dummy_input_names)

# We are fine if config_inputs has more keys than model_inputs
if not dummy_input_names.issubset(forward_inputs_set):
raise ValueError(
f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}"
)


def validate_models_outputs(
models_and_onnx_configs: Dict[
str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]
Expand Down
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from packaging import version
from transformers.utils import is_tf_available

from ...onnx import merge_decoders
from ...utils import (
DEFAULT_DUMMY_SHAPES,
BloomDummyPastKeyValuesGenerator,
Expand Down Expand Up @@ -1875,6 +1874,7 @@ def post_process_exported_models(
decoder_with_past_path = Path(path, onnx_files_subpaths[3])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
try:
from ...onnx import merge_decoders
# The decoder with past does not output the cross attention past key values as they are constant,
# hence the need for strict=False
merge_decoders(
Expand Down
27 changes: 26 additions & 1 deletion optimum/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
"""Utilities for model preparation to export."""

import copy
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from packaging import version
Expand Down Expand Up @@ -675,3 +676,27 @@ def _get_submodels_and_export_configs(
export_config = next(iter(models_and_export_configs.values()))[1]

return export_config, models_and_export_configs


def check_dummy_inputs_are_allowed(
model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], dummy_input_names: Iterable[str]
):
"""
Checks that the dummy inputs from the ONNX config is a subset of the allowed inputs for `model`.
Args:
model (`Union[transformers.PreTrainedModel, transformers.TFPreTrainedModel`]):
The model instance.
model_inputs (`Iterable[str]`):
The model input names.
"""

forward = model.forward if is_torch_available() and isinstance(model, nn.Module) else model.call
IlyasMoutawwakil marked this conversation as resolved.
Show resolved Hide resolved
forward_parameters = signature(forward).parameters
forward_inputs_set = set(forward_parameters.keys())
dummy_input_names = set(dummy_input_names)

# We are fine if config_inputs has more keys than model_inputs
if not dummy_input_names.issubset(forward_inputs_set):
raise ValueError(
f"Config dummy inputs are not a subset of the model inputs: {dummy_input_names} vs {forward_inputs_set}"
)
Loading