Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SpeechT5 ONNX support #1404

Merged
merged 18 commits into from
Oct 18, 2023
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ Supported architectures:
- SEW
- SEW-D
- Speech2Text
- SpeechT5
- Splinter
- SqueezeBert
- Stable Diffusion
Expand Down
7 changes: 7 additions & 0 deletions optimum/commands/export/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Defines the command line for the export with ONNX."""

import argparse
import json
from pathlib import Path
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -143,6 +144,11 @@ def parse_args_onnx(parser):
"Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum."
),
)
optional_group.add_argument(
"--model-kwargs",
type=json.loads,
help=("Any kwargs passed to the model forward, or used to customize the export for a given model."),
)

input_group = parser.add_argument_group(
"Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)."
Expand Down Expand Up @@ -256,5 +262,6 @@ def run(self):
_variant=self.args.variant,
library_name=self.args.library_name,
no_position_ids=self.args.no_position_ids,
model_kwargs=self.args.model_kwargs,
**input_shapes,
)
42 changes: 35 additions & 7 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pathlib import Path

from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import is_torch_available

from ...commands.export.onnx import parse_args_onnx
Expand All @@ -38,6 +38,7 @@
get_decoder_models_for_export,
get_encoder_decoder_models_for_export,
get_sam_models_for_export,
get_speecht5_models_for_export,
get_stable_diffusion_models_for_export,
)

Expand Down Expand Up @@ -69,6 +70,7 @@ def _get_submodels_and_onnx_configs(
fn_get_submodels: Optional[Callable] = None,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
model_kwargs: Optional[Dict] = None,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
Expand All @@ -95,10 +97,11 @@ def _get_submodels_and_onnx_configs(

onnx_config.variant = _variant
all_variants = "\n".join(
[f"\t- {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
[f" - {name}: {description}" for name, description in onnx_config.VARIANTS.items()]
)
logger.info(f"Using the export variant {onnx_config.variant}. Available variants are:\n{all_variants}")

# TODO: this succession of if/else strongly suggests a refactor is needed.
if (
model.config.is_encoder_decoder
and task.startswith(TasksManager._ENCODER_DECODER_TASKS)
Expand All @@ -109,6 +112,8 @@ def _get_submodels_and_onnx_configs(
models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config)
elif model.config.model_type == "sam":
models_and_onnx_configs = get_sam_models_for_export(model, onnx_config)
elif model.config.model_type == "speecht5":
models_and_onnx_configs = get_speecht5_models_for_export(model, onnx_config, model_kwargs)
else:
models_and_onnx_configs = {"model": (model, onnx_config)}

Expand Down Expand Up @@ -333,6 +338,30 @@ def main_export(
f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
)

if library_name == "transformers":
config = AutoConfig.from_pretrained(
model_name_or_path,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
)
model_type = config.model_type.replace("_", "-")
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
custom_architecture = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fxmarty this line currently does nothing since it is set to False again in line 381. Do you want to have a look?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! I'll fix

elif task not in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx"):
if original_task == "auto":
autodetected_message = " (auto-detected)"
else:
autodetected_message = ""
model_tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx")
raise ValueError(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum ONNX exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)

model = TasksManager.get_model_from_task(
task,
model_name_or_path,
Expand Down Expand Up @@ -361,18 +390,16 @@ def main_export(
if not is_stable_diffusion:
if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE:
raise ValueError(
f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. "
f"{model_type} is not supported yet. Only {list(TasksManager._SUPPORTED_CLI_MODEL_TYPE.keys())} are supported. "
f"If you want to support {model_type} please propose a PR or open up an issue."
)
if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task(
task, exporter="onnx"
):
if model.config.model_type.replace("_", "-") not in TasksManager._SUPPORTED_MODEL_TYPE:
custom_architecture = True

# TODO: support onnx_config.py in the model repo
if custom_architecture and custom_onnx_configs is None:
raise ValueError(
f"Trying to export a {model.config.model_type.replace('-', '_')} model, that is a custom or unsupported architecture for the task {task}, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. For the task {task}, the Optimum ONNX exporter supports natively the architectures: {TasksManager.get_supported_model_type_for_task(task, exporter='onnx')}."
f"Trying to export a {model.config.model_type} model, that is a custom or unsupported architecture for the task {task}, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model.config.model_type} to be supported natively in the ONNX export."
)

if custom_architecture and original_task == "auto":
Expand Down Expand Up @@ -425,6 +452,7 @@ def main_export(
preprocessors=preprocessors,
_variant=_variant,
no_position_ids=no_position_ids,
model_kwargs=model_kwargs,
)

if not is_stable_diffusion:
Expand Down
23 changes: 14 additions & 9 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ class OnnxConfig(ExportConfig, ABC):
MIN_TRANSFORMERS_VERSION = GLOBAL_MIN_TRANSFORMERS_VERSION
PATCHING_SPECS: Optional[List["PatchingSpec"]] = None
VARIANTS = {"default": "The default ONNX variant."}
DEFAULT_VARIANT = "default"
_TASK_TO_COMMON_OUTPUTS = {
"audio-classification": OrderedDict({"logits": {0: "batch_size"}}),
"audio-frame-classification": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
Expand Down Expand Up @@ -200,17 +201,14 @@ def __init__(
int_dtype: str = "int64",
float_dtype: str = "fp32",
):
if task not in self._TASK_TO_COMMON_OUTPUTS:
raise ValueError(
f"{task} is not a supported task, supported tasks: {', '.join(self._TASK_TO_COMMON_OUTPUTS.keys())}"
)
self.task = task
self.int_dtype = int_dtype
self.float_dtype = float_dtype

self._config = config
self._preprocessors = preprocessors
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.variant = "default"

def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]:
"""
Expand Down Expand Up @@ -808,7 +806,8 @@ def with_behavior(
"""
if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
behavior = ConfigBehavior(behavior)
return self.__class__(

onnx_config = self.__class__(
self._config,
task=self.task,
int_dtype=self.int_dtype,
Expand All @@ -818,6 +817,8 @@ def with_behavior(
behavior=behavior,
preprocessors=self._preprocessors,
)
onnx_config.variant = self.variant
return onnx_config

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
Expand Down Expand Up @@ -902,8 +903,8 @@ def post_process_exported_models(
path, models_and_onnx_configs, onnx_files_subpaths
)

# Attempt to merge only if the decoder was exported without/with past
if self.use_past is True and len(models_and_onnx_configs) == 3:
# Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task
if len(onnx_files_subpaths) >= 3 and self.use_past is True or self.variant == "with-past":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to check self.variant ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure. I'll need to double check.

decoder_path = Path(path, onnx_files_subpaths[1])
decoder_with_past_path = Path(path, onnx_files_subpaths[2])
decoder_merged_path = Path(path, ONNX_DECODER_MERGED_NAME + ".onnx")
Expand All @@ -922,7 +923,8 @@ def post_process_exported_models(
# In order to do the validation of the two branches on the same file
encoder_path = onnx_files_subpaths[0]

onnx_files_subpaths = [encoder_path, decoder_merged_path.name, decoder_merged_path.name]
onnx_files_subpaths_new = [encoder_path, decoder_merged_path.name, decoder_merged_path.name]
onnx_files_subpaths_new.extend(onnx_files_subpaths[3:])

# We validate the two branches of the decoder model then
models_and_onnx_configs[ONNX_DECODER_NAME][1].is_merged = True
Expand All @@ -933,8 +935,10 @@ def post_process_exported_models(

models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].use_cache_branch = True
models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1].is_merged = True
else:
onnx_files_subpaths_new = onnx_files_subpaths

return models_and_onnx_configs, onnx_files_subpaths
return models_and_onnx_configs, onnx_files_subpaths_new

def generate_dummy_inputs_for_validation(
self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None
Expand Down Expand Up @@ -1006,6 +1010,7 @@ def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: st
self.float_dtype = float_dtype
self._normalized_config = self._onnx_config._normalized_config
self.PATCHING_SPECS = self._onnx_config.PATCHING_SPECS
self.variant = "default"

@classmethod
def from_onnx_config(cls, config: OnnxConfig) -> "OnnxConfigWithLoss":
Expand Down
19 changes: 10 additions & 9 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from ..error_utils import AtolError, MinimumVersionError, OutputMatchError, ShapeError
from .base import OnnxConfig
from .model_configs import SpeechT5OnnxConfig
from .utils import PickableInferenceSession, recursive_to_device


Expand Down Expand Up @@ -142,15 +143,13 @@ def validate_models_outputs(
if use_subprocess:
logger.info("Validating models in subprocesses...")
exceptions = [] # run all validations before raising
onnx_paths = []
for i, model_name in enumerate(models_and_onnx_configs.keys()):
submodel, sub_onnx_config = models_and_onnx_configs[model_name]
onnx_model_path = (
output_dir.joinpath(onnx_files_subpaths[i])
if onnx_files_subpaths is not None
else output_dir.joinpath(model_name + ".onnx")
)
onnx_paths.append(onnx_model_path)
try:
# Model validation is done in subprocesses, as ONNX Runtime has the bad habit of
# not releasing memory once an InferenceSession is initialized.
Expand All @@ -168,12 +167,12 @@ def validate_models_outputs(
model_kwargs=model_kwargs,
)
except Exception as e:
exceptions.append(e)
exceptions.append((onnx_model_path, e))

if len(exceptions) != 0:
for i, exception in enumerate(exceptions[:-1]):
logger.error(f"Validation {i} for the model {onnx_paths[i].as_posix()} raised: {exception}")
raise exceptions[-1]
logger.error(f"Validation for the model {exception[0].as_posix()} raised: {exception[1]}")
raise exceptions[-1][1]


def validate_model_outputs(
Expand Down Expand Up @@ -423,9 +422,11 @@ def _run_validation(

if value_failures:
msg = "\n".join(f"- {t[0]}: max diff = {t[1]}" for t in value_failures)
raise AtolError(
f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}"
)
atol_msg = f"The maximum absolute difference between the output of the reference model and the ONNX exported model is not within the set tolerance {atol}:\n{msg}"

if isinstance(config, SpeechT5OnnxConfig):
atol_msg += "\nIMPORTANT NOTE: SpeechT5 uses a dropout at inference and the output validation of ONNX Runtime inference vs PyTorch is expected to fail. Reference: https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L727"
raise AtolError(atol_msg)


class ValidationProcess(mp.Process):
Expand Down Expand Up @@ -526,7 +527,7 @@ def export_pytorch(

with torch.no_grad():
model.config.return_dict = True
model.eval()
model = model.eval()

# Check if we need to override certain configuration item
if config.values_override is not None:
Expand Down
Loading