diff --git a/.github/workflows/test_onnxruntime_slow.yml b/.github/workflows/test_onnxruntime_slow.yml new file mode 100644 index 00000000000..20371f79150 --- /dev/null +++ b/.github/workflows/test_onnxruntime_slow.yml @@ -0,0 +1,33 @@ +name: ONNX Runtime slow / Python - Test + +on: + workflow_dispatch: + schedule: + - cron: 0 7 * * * # every day at 7am + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + strategy: + fail-fast: false + matrix: + python-version: [3.8, 3.9] + os: [ubuntu-20.04] + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies for export + run: | + pip install .[tests,onnxruntime] + - name: Test with unittest + working-directory: tests + run: | + RUN_SLOW=1 pytest onnxruntime -s -m "run_slow" --durations=0 diff --git a/optimum/commands/export/onnx.py b/optimum/commands/export/onnx.py index d496f6f0392..85661ccf6cf 100644 --- a/optimum/commands/export/onnx.py +++ b/optimum/commands/export/onnx.py @@ -136,14 +136,6 @@ def parse_args_onnx(parser): default=None, help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library"), ) - optional_group.add_argument( - "--no-position-ids", - action="store_true", - help=( - "Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum." - ), - ) - input_group = parser.add_argument_group( "Input shapes (if necessary, this allows to override the shapes of the input given to the ONNX exporter, that requires an example input)." ) @@ -217,6 +209,14 @@ def parse_args_onnx(parser): default=DEFAULT_DUMMY_SHAPES["nb_points_per_image"], help="For Segment Anything. It corresponds to the number of points per segmentation masks.", ) + optional_group.add_argument( + "--legacy", + action="store_true", + help=( + "Export decoder only models in three files (without + with past and the resulting merged model)." + "Also disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum." + ), + ) # deprecated argument parser.add_argument("--for-ort", action="store_true", help=argparse.SUPPRESS) @@ -255,6 +255,6 @@ def run(self): use_subprocess=True, _variant=self.args.variant, library_name=self.args.library_name, - no_position_ids=self.args.no_position_ids, + legacy=self.args.legacy, **input_shapes, ) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 16a18afc552..1b601cdfb8d 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -68,7 +68,7 @@ def _get_submodels_and_onnx_configs( float_dtype: str = "fp32", fn_get_submodels: Optional[Callable] = None, preprocessors: Optional[List[Any]] = None, - no_position_ids: bool = False, + legacy: bool = False, ): is_stable_diffusion = "stable-diffusion" in task if not custom_architecture: @@ -82,8 +82,8 @@ def _get_submodels_and_onnx_configs( model=model, exporter="onnx", task=task ) onnx_config_kwargs = {} - if task.startswith("text-generation") and no_position_ids: - onnx_config_kwargs["no_position_ids"] = no_position_ids + if task.startswith("text-generation") and legacy: + onnx_config_kwargs["no_position_ids"] = legacy onnx_config = onnx_config_constructor( model.config, @@ -106,7 +106,7 @@ def _get_submodels_and_onnx_configs( ): models_and_onnx_configs = get_encoder_decoder_models_for_export(model, onnx_config) elif task.startswith("text-generation") and not monolith: - models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config) + models_and_onnx_configs = get_decoder_models_for_export(model, onnx_config, legacy=legacy) elif model.config.model_type == "sam": models_and_onnx_configs = get_sam_models_for_export(model, onnx_config) else: @@ -184,7 +184,7 @@ def main_export( use_subprocess: bool = False, _variant: str = "default", library_name: Optional[str] = None, - no_position_ids: bool = False, + legacy: bool = False, **kwargs_shapes, ): """ @@ -264,8 +264,8 @@ def main_export( library_name (`Optional[str]`, defaults to `None`): The library of the model(`"tansformers"` or `"diffusers"` or `"timm"`). If not provided, will attempt to automatically detect the library name for the checkpoint. - no_position_ids (`bool`, defaults to `False`): - Disable the use of position_ids for text-generation models that require it for batched generation. This argument is introduced for backward compatibility and will be removed in a future release of Optimum. + legacy (`bool`, defaults to `False`): + Disable the use of position_ids for text-generation models that require it for batched generation. Also enable to export decoder only models in three files (without + with past and the merged model). This argument is introduced for backward compatibility and will be removed in a future release of Optimum. **kwargs_shapes (`Dict`): Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export. @@ -353,9 +353,9 @@ def main_export( is_stable_diffusion = "stable-diffusion" in task model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-") - if no_position_ids and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"): + if legacy and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"): logger.warning( - f"no_position_ids=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `no_position_ids=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381" + f"legacy=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `legacy=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381" ) if not is_stable_diffusion: @@ -424,7 +424,7 @@ def main_export( fn_get_submodels=fn_get_submodels, preprocessors=preprocessors, _variant=_variant, - no_position_ids=no_position_ids, + legacy=legacy, ) if not is_stable_diffusion: @@ -610,6 +610,7 @@ def main(): pad_token_id=args.pad_token_id, for_ort=args.for_ort, library_name=args.library_name, + legacy=args.legacy, **input_shapes, ) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 1e2ae99955c..a65374346ac 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -585,7 +585,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]: elif self.task == "feature-extraction": common_outputs = OrderedDict({"last_hidden_state": {0: "batch_size"}}) else: - common_outputs = OrderedDict({"logits": {0: "batch_size"}}) + common_outputs = OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}) if self.use_past: # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output. self.add_past_key_values(common_outputs, direction="outputs") diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 9259ad853da..3aca641513c 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -92,7 +92,7 @@ def __init__( @property def inputs(self) -> Dict[str, Dict[int, str]]: if self.use_past_in_inputs: - common_inputs = {"input_ids": {0: "batch_size"}} + common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}} self.add_past_key_values(common_inputs, direction="inputs") common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"} else: @@ -164,10 +164,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: # generating wrong position_ids in the model itself: # https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802 if not self.no_position_ids and self.task == "text-generation": - if self.use_past_in_inputs: - common_inputs["position_ids"] = {0: "batch_size"} - else: - common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} + common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} return common_inputs diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 73308e24a5d..a83c8a91fa5 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -56,7 +56,15 @@ TextSeq2SeqOnnxConfig, VisionOnnxConfig, ) -from .model_patcher import SAMModelPatcher, WavLMModelPatcher +from .model_patcher import ( + BartModelPatcher, + BloomModelPatcher, + LlamaModelPatcher, + MistralModelPatcher, + OPTModelPatcher, + SAMModelPatcher, + WavLMModelPatcher, +) if TYPE_CHECKING: @@ -216,6 +224,11 @@ class OPTOnnxConfig(TextDecoderOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return OPTModelPatcher(self, model, model_kwargs=model_kwargs) + class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) @@ -223,6 +236,11 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return LlamaModelPatcher(self, model, model_kwargs=model_kwargs) + class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): # The ONNX export of this architecture needs the Trilu operator support, available since opset 14 @@ -233,6 +251,11 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return MistralModelPatcher(self, model, model_kwargs=model_kwargs) + class MPTOnnxConfig(TextDecoderOnnxConfig): # MPT does not require position_ids input. @@ -241,6 +264,11 @@ class MPTOnnxConfig(TextDecoderOnnxConfig): num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers" ) + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return BloomModelPatcher(self, model, model_kwargs=model_kwargs) + class BloomOnnxConfig(TextDecoderOnnxConfig): # Bloom does not require position_ids input. @@ -274,6 +302,11 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire 1: decoder_sequence_name, } + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return BloomModelPatcher(self, model, model_kwargs=model_kwargs) + class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( @@ -413,7 +446,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return int_tensor -class BartOnnxConfig(TextSeq2SeqOnnxConfig): +class M2M100OnnxConfig(TextSeq2SeqOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( encoder_num_layers="encoder_layers", decoder_num_layers="decoder_layers", @@ -537,11 +570,14 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): ) -class MBartOnnxConfig(BartOnnxConfig): - pass +class BartOnnxConfig(M2M100OnnxConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return BartModelPatcher(self, model, model_kwargs=model_kwargs) -class M2M100OnnxConfig(BartOnnxConfig): +class MBartOnnxConfig(BartOnnxConfig): pass diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index e6b50b6dc08..aa14526bd8c 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -19,6 +19,12 @@ from transformers.utils import is_torch_available +from ...utils.modeling_utils import ( + _prepare_attn_mask, + _prepare_decoder_attention_mask, + _prepare_decoder_sliding_window_attention_mask, +) + if is_torch_available(): import torch @@ -342,3 +348,103 @@ def patched_forward( return {"iou_scores": iou_predictions, "pred_masks": low_res_masks} self.patched_forward = patched_forward + + +class CausalAttentionMaskModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + self.patch = self.real_config.task == "text-generation" and self.real_config.use_past + + def __enter__(self): + super().__enter__() + if self.patch: + setattr(self._model_to_patch, self._orig_func_name, self._patch_func.__get__(self._model_to_patch)) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if self.patch: + setattr(self._model_to_patch, self._orig_func_name, self._orig_func.__get__(self._model_to_patch)) + + +class BloomModelPatcher(CausalAttentionMaskModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + if self.patch: + self._model_to_patch = model.transformer + self._patch_func = _prepare_attn_mask + self._orig_func_name = "_prepare_attn_mask" + self._orig_func = self._model_to_patch._prepare_attn_mask + + +class OPTModelPatcher(CausalAttentionMaskModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + if self.patch: + self._model_to_patch = model.model.decoder + self._patch_func = _prepare_decoder_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" + self._orig_func = self._model_to_patch._prepare_decoder_attention_mask + + +class LlamaModelPatcher(CausalAttentionMaskModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + if self.patch: + self._model_to_patch = model.model + self._patch_func = _prepare_decoder_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" + self._orig_func = self._model_to_patch._prepare_decoder_attention_mask + + +class MistralModelPatcher(CausalAttentionMaskModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + if self.patch: + self._model_to_patch = model.model + self._patch_func = _prepare_decoder_sliding_window_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" + self._orig_func = self._model_to_patch._prepare_decoder_attention_mask + + +class BartModelPatcher(CausalAttentionMaskModelPatcher, Seq2SeqModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + if self.patch: + self._model_to_patch = model.model.decoder + self._patch_func = _prepare_decoder_attention_mask + self._orig_func_name = "_prepare_decoder_attention_mask" + self._orig_func = self._model_to_patch._prepare_decoder_attention_mask diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 6e90fc617fb..2dda5594a66 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -29,6 +29,7 @@ logging, ) from ...utils.import_utils import _diffusers_version +from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask # noqa: F401 from ..tasks import TasksManager from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME @@ -159,15 +160,16 @@ def _get_submodels_for_export_stable_diffusion( def _get_submodels_for_export_decoder( - model: Union["PreTrainedModel", "TFPreTrainedModel"], use_past: bool + model: Union["PreTrainedModel", "TFPreTrainedModel"], + use_past: bool, + legacy: bool = False, ) -> Dict[str, Union["PreTrainedModel", "TFPreTrainedModel"]]: """ Returns the decoder part of the model. """ - models_for_export = {} + models_for_export = {ONNX_DECODER_NAME if legacy else "model": model} - models_for_export[ONNX_DECODER_NAME] = model - if use_past: + if legacy and use_past: models_for_export[ONNX_DECODER_WITH_PAST_NAME] = model return models_for_export @@ -227,6 +229,7 @@ def get_encoder_decoder_models_for_export( def get_decoder_models_for_export( model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "OnnxConfig", + legacy: bool = False, ) -> Dict[str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel"], "OnnxConfig"]]: """ Returns two versions of the decoder that can be used together to perform fast generation: @@ -246,31 +249,42 @@ def get_decoder_models_for_export( `Dict[str, Tuple[Union[PreTrainedModel, TFPreTrainedModel], OnnxConfig]]: A Dict containing the model and onnx configs for the encoder and decoder parts of the model. """ - models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past) + + models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy) onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype} if model.config.model_type.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: onnx_kwargs["no_position_ids"] = config.no_position_ids - onnx_config = config.__class__( - model.config, - use_past=config.use_past, - use_past_in_inputs=False, - **onnx_kwargs, - ) - models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config) - - if config.use_past: - onnx_config_with_past = config.__class__( + if legacy: + onnx_config = config.__class__( model.config, - use_past=True, - use_past_in_inputs=True, + use_past=config.use_past, + use_past_in_inputs=False, **onnx_kwargs, ) - models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( - models_for_export[ONNX_DECODER_WITH_PAST_NAME], - onnx_config_with_past, + models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], onnx_config) + + if config.use_past: + onnx_config_with_past = config.__class__( + model.config, + use_past=True, + use_past_in_inputs=True, + **onnx_kwargs, + ) + models_for_export[ONNX_DECODER_WITH_PAST_NAME] = ( + models_for_export[ONNX_DECODER_WITH_PAST_NAME], + onnx_config_with_past, + ) + + else: + onnx_config = config.__class__( + model.config, + use_past=config.use_past, + use_past_in_inputs=config.use_past, + **onnx_kwargs, ) + models_for_export["model"] = (models_for_export["model"], onnx_config) return models_for_export diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 7a5c6364fe2..2707c6eeab2 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -14,42 +14,34 @@ """Classes handling causal-lm related architectures in ONNX Runtime.""" import logging -import shutil from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +import numpy as np +import onnx import torch -from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError +from onnx.tools import update_model_dims from transformers import AutoModelForCausalLM, GenerationConfig from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions +from transformers.modeling_outputs import CausalLMOutputWithPast import onnxruntime from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export -from ..onnx.utils import _get_external_data_paths -from ..utils import check_if_transformers_greater -from ..utils.file_utils import validate_file_exists -from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors -from .base import ORTDecoder +from ..onnx.utils import check_model_uses_external_data +from ..utils import NormalizedConfigManager, check_if_transformers_greater +from ..utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST +from ..utils.save_utils import maybe_save_preprocessors from .constants import DECODER_MERGED_ONNX_FILE_PATTERN, DECODER_ONNX_FILE_PATTERN, DECODER_WITH_PAST_ONNX_FILE_PATTERN from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel from .models.bloom import bloom_convert_to_bloom_cache, bloom_convert_to_standard_cache -from .utils import ( - ONNX_DECODER_NAME, - ONNX_DECODER_WITH_PAST_NAME, - get_provider_for_device, - parse_device, - validate_provider_availability, -) +from .utils import MULTI_QUERY_ATTN_MODELS, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_WEIGHTS_NAME if TYPE_CHECKING: from transformers import PretrainedConfig - if check_if_transformers_greater("4.25.0"): from transformers.generation import GenerationMixin else: @@ -119,220 +111,293 @@ """ -class ORTModelDecoder(ORTModel): +@add_end_docstrings(ONNX_MODEL_END_DOCSTRING) +class ORTModelForCausalLM(ORTModel, GenerationMixin): """ - Base class for implementing models with a causal language modeling head using ONNX Runtime inference. + ONNX model with a causal language modeling head for ONNX Runtime inference. This class officially supports bloom, codegen, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gptj, llama. """ + auto_model_class = AutoModelForCausalLM + main_input_name = "input_ids" + def __init__( self, - decoder_session: onnxruntime.InferenceSession, + model: onnxruntime.InferenceSession, config: "PretrainedConfig", - onnx_paths: List[str], - decoder_with_past_session: Optional[onnxruntime.InferenceSession] = None, - use_cache: bool = True, use_io_binding: Optional[bool] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, preprocessors: Optional[List] = None, generation_config: Optional[GenerationConfig] = None, + use_cache: Optional[bool] = None, **kwargs, ): - """ - Args: - decoder_session (`onnxruntime.InferenceSession`): - The ONNX Runtime inference session associated to the decoder. - config ([`~transformers.PretrainedConfig`]): - An instance of the configuration associated to the model. Initializing with a config file does - not load the weights associated with the model, only the configuration. - decoder_with_past_session (`Optional[onnxruntime.InferenceSession]`, defaults to `None`): - The ONNX Runtime inference session associated to the decoder with past key values. This argument should not - be set if use_merged=True is used. - onnx_paths (`List[str]`): - Path to ONNX files associated with the model. - use_cache (`bool`, defaults to `True`): - Whether or not past key/values cache should be used. Defaults to `True`. - use_io_binding (`Optional[bool]`, defaults to `None`): - Whether to use IOBinding during inference to avoid memory copy between the host and devices. Defaults to - `True` if the execution provider is CPUExecutionProvider or CUDAExecutionProvider, otherwise defaults to `False`. - model_save_dir (`Optional[Union[str, Path, TemporaryDirectory]]`, defaults to `""`): - The directory under which the model exported to ONNX was saved. - preprocessors (`Optional[List]`, defaults to `None`): - The list of the preprocessors (tokenizer, processor, feature_extractor) to save alongside the ORTModel. - generation_config (`Optional[GenerationConfig]`, defaults to `None`): - The generation configuration used by default when calling `generate()`. - Refer to https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate. - """ if use_io_binding is None: - if decoder_session.get_providers()[0] in ["CPUExecutionProvider", "CUDAExecutionProvider"]: - use_io_binding = True - else: - use_io_binding = False + use_io_binding = model.get_providers()[0] in ["CPUExecutionProvider", "CUDAExecutionProvider"] - self.shared_attributes_init( - decoder_session, - use_io_binding=use_io_binding, - model_save_dir=model_save_dir, - ) - self.config = config + super().__init__(model, config, use_io_binding, model_save_dir, preprocessors, **kwargs) - # TODO: remove at version 2.0 - def show_deprecated_argument(arg_name): - if kwargs.pop(arg_name, None) is not None: - logger.warning( - f"The {arg_name} argument to create an {self.__class__.__name__} is deprecated, and not used " - "anymore." - ) + self.num_pkv = 2 + self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) + self.key_value_input_names = [key for key in self.inputs_names if (".key" in key) or (".value" in key)] + self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)] + self.use_cache = len(self.key_value_input_names) > 0 - show_deprecated_argument("last_decoder_model_name") - show_deprecated_argument("last_decoder_with_past_model_name") - if kwargs: - raise ValueError( - f"{self.__class__.__name__} received {', '.join(kwargs.keys())}, but do not accept those arguments." - ) + if generation_config is None: + generation_config = GenerationConfig.from_model_config(config) + self.generation_config = generation_config + self.onnx_paths = [self.model_path] + self.use_merged = "use_cache_branch" in self.inputs_names - if use_cache is True: - # Auto-detect whether the provided session is a merged non-past / with-past or not - # TODO: make __init__ private and pass `use_merged` as an argument - use_merged = "use_cache_branch" in [inp.name for inp in decoder_session.get_inputs()] + self.use_fp16 = False + for inp in model.get_inputs(): + if inp.name == "past_key_values" and inp.type == "tensor(float16)": + self.use_fp16 = True + break - if use_merged is True and decoder_with_past_session is not None: - raise ValueError( - "Detected a merged decoder, but decoder_with_past_session was provided." - "Please only set decoder_session, or provide a non-merged decoder_session." - ) - if use_cache is True and use_merged is False and decoder_with_past_session is None: - raise ValueError( - "The parameter use_cache was set as True, but neither decoder_with_past_session was passed" - " nor a use_cache branch can be found in the decoder_session." - " Please pass a decoder_with_past_session or set use_cache=False." - ) - else: - use_merged = False + # Reference: https://github.com/huggingface/optimum/pull/1381 + model_type = config.model_type.replace("_", "-") + if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.inputs_names: + logger.warning( + f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. " + "We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support." + ) - if decoder_with_past_session is not None: - raise ValueError( - "The parameter decoder_with_past_session was passed, although use_cache is False." - "Please pass use_cache=True for decoder_with_past_session to be used." - ) + if use_cache ^ self.use_cache: + raise ValueError( + f"`use_cache` was set to `{use_cache}` but the loaded model only supports `use_cache={self.use_cache}`. " + f"Please load your current model with `use_cache={self.use_cache}` or export the original model " + f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. " + "To export your model, simply set `export=True`." + ) - if use_cache is False and use_io_binding is True: + if use_io_binding and not use_cache: raise ValueError( - "When using CUDAExecutionProvider, the parameters combination use_cache=False, use_io_binding=True" - " is not supported. Please either pass use_cache=True, use_io_binding=True (default)," - " or use_cache=False, use_io_binding=False." + "The parameters combination use_cache=False, use_io_binding=True is not supported. " + "Please either pass use_cache=True, use_io_binding=True (default), or use_cache=False, use_io_binding=False." ) - self.onnx_paths = onnx_paths - self.use_cache = use_cache - self.use_merged = use_merged - self.decoder = ORTDecoder(decoder_session, self) - self.decoder_model_path = Path(decoder_session._model_path) - self.decoder_model_name = self.decoder_model_path.name + @add_start_docstrings_to_model_forward( + CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length") + + TEXT_GENERATION_EXAMPLE.format( + processor_class=_TOKENIZER_FOR_DOC, + model_class="ORTModelForCausalLM", + checkpoint="optimum/gpt2", + ) + ) + def forward( + self, + input_ids: torch.LongTensor, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache_branch: bool = None, + **kwargs, + ) -> CausalLMOutputWithPast: + # adding use_cache_branch in the signature here is just a hack for IO Binding + use_torch = isinstance(input_ids, torch.Tensor) + self.raise_on_numpy_input_io_binding(use_torch) + + inputs = {} + known_output_shapes = {} + use_cache_branch = None + loss = None + if self.use_cache: + if past_key_values is not None: + input_ids = input_ids[:, -1:] + # Flatten the past_key_values (no need to flatten for models using multi-query attn) + if self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + past_key_values = tuple( + past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer + ) - # Reference: https://github.com/huggingface/optimum/pull/1381 - model_type = config.model_type.replace("_", "-") - if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.decoder.input_names: - logger.warning( - f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although this input is required for batched generation for the architecture {model_type}. We strongly encourage to re-export the model with optimum>=1.14 for position_ids and batched inference support." + # Create dummy past_key_values for decoder first generation step if none given + use_cache_branch, past_key_values, known_output_shapes = self.prepare_past_key_values( + input_ids, past_key_values, use_torch ) - self.decoder_with_past = None - self.decoder_with_past_model_path = None - self.decoder_with_past_model_name = None - if self.use_cache is True and self.use_merged is False: - self.decoder_with_past = ORTDecoder(decoder_with_past_session, self) - self.decoder_with_past_model_path = Path(decoder_with_past_session._model_path) - self.decoder_with_past_model_name = self.decoder_with_past_model_path.name + if self.use_io_binding: + # TODO: fix transformers generate to have contiguous input_ids here already + # For an unknown reason, calling `contiguous()` here is necessary to not have errors + # on CPU EP with batch size > 1, despite it being also called in _prepare_io_binding. + # I suspect the reason is the contiguous python list that messes something up? + model_inputs = [input_ids.contiguous()] - if generation_config is None: - generation_config = GenerationConfig.from_model_config(config) - self.generation_config = generation_config + if "attention_mask" in self.inputs_names: + model_inputs.append(attention_mask) - @staticmethod - def _generate_regular_names_for_filename(filename: str): - name, extension = filename.rsplit(".", maxsplit=1) - return [ - filename, - f"{name}_quantized.{extension}", - f"{name}_optimized.{extension}", - f"{name}_merged.{extension}", - ] + if "position_ids" in self.inputs_names: + if position_ids is None: + raise ValueError("position_ids was not passed but is a required input for this ONNX model.") + model_inputs.append(position_ids.contiguous()) - @staticmethod - def load_model( - decoder_path: Union[str, Path], - decoder_with_past_path: Optional[Union[str, Path]] = None, - provider: str = "CPUExecutionProvider", - session_options: Optional[onnxruntime.SessionOptions] = None, - provider_options: Optional[Dict] = None, - ): - """ - Creates an instance of [`~optimum.onnxruntime.ORTModelDecoder`]. - Three inference sessions will be created for respectively the decoder and decoder with past key values - models. The default provider is `CPUExecutionProvider` to match the default behaviour in PyTorch/TensorFlow/JAX. - - Args: - decoder_path (`str` or `Path`): - The path of the decoder ONNX model. - decoder_with_past_path (`str` or `Path`, *optional*): - The path of the decoder with past key values ONNX model. - provider(`str`, *optional*, defaults to `"CPUExecutionProvider"`): - The ONNX Runtime provider to use for loading the model. - session_options (`Optional[onnxruntime.SessionOptions]`, *optional*),: - ONNX Runtime session options to use for loading the model. - provider_options (`Optional[Dict]`, *optional*): - Provider option dictionary corresponding to the provider used. See available options - for each provider: https://onnxruntime.ai/docs/api/c/group___global.html. - """ - decoder_session = ORTModel.load_model(decoder_path, provider, session_options, provider_options) - - decoder_with_past_session = None - # If a decoder_with_past_path is provided, an inference session for the decoder with past key/values as inputs - # will be enabled - if decoder_with_past_path is not None: - decoder_with_past_session = ORTModel.load_model( - decoder_with_past_path, provider, session_options, provider_options + if past_key_values is not None: + model_inputs += past_key_values + + if use_cache_branch is not None: + model_inputs.append(use_cache_branch) + + if "labels" in self.inputs_names: + model_inputs.append(labels) + known_output_shapes.update({"loss": []}) + + io_binding, output_shapes, output_buffers = self._prepare_io_binding( + self.model, + *model_inputs, + known_output_shapes=known_output_shapes, + ordered_input_names=self._ordered_input_names, ) - return decoder_session, decoder_with_past_session + if self.device.type == "cpu": + self.model.run_with_iobinding(io_binding) + else: + io_binding.synchronize_inputs() + self.model.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + if self.use_cache: + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer(2) + past_key_values = () + for name in self.key_value_output_names: + past_key_values += (output_buffers[name].view(output_shapes[name]),) + + logits = output_buffers["logits"].view(output_shapes["logits"]) + + if "loss" in self.output_names: + loss = output_buffers["loss"].view(output_shapes["loss"]) + + else: + inputs["input_ids"] = input_ids.cpu().detach().numpy() if use_torch else input_ids + + if "attention_mask" in self.inputs_names: + inputs["attention_mask"] = attention_mask.cpu().detach().numpy() if use_torch else attention_mask + + if "labels" in self.inputs_names: + inputs["labels"] = labels.cpu().detach().numpy() if use_torch else labels - def _save_pretrained(self, save_directory: Union[str, Path]): - """ - Saves the model decoder and decoder with past key values as well as its configuration file to a - directory, so that it can be re-loaded using the - [`~optimum.onnxruntime.modeling_causal.ORTModelDecoder.from_pretrained`] class method. + if "position_ids" in self.inputs_names: + if position_ids is None: + raise ValueError("position_ids was not passed but is a required input for this ONNX model.") + inputs["position_ids"] = position_ids.cpu().detach().numpy() if use_torch else position_ids - Args: - save_directory (`str` or `Path`): - The directory where to save the model files. - """ - save_directory = Path(save_directory) - src_paths = [Path(path) for path in self.onnx_paths] - dst_paths = [save_directory / path.name for path in src_paths] + # Add the past_key_values to the decoder inputs + if past_key_values is not None: + for input_name, past_key_value in zip(self.key_value_input_names, past_key_values): + inputs[input_name] = past_key_value.cpu().detach().numpy() if use_torch else past_key_value - # add external data paths in case of large models - src_paths, dst_paths = _get_external_data_paths(src_paths, dst_paths) + if use_cache_branch is not None: + inputs["use_cache_branch"] = use_cache_branch.cpu().detach().numpy() if use_torch else use_cache_branch - for src_path, dst_path in zip(src_paths, dst_paths): - shutil.copyfile(src_path, dst_path) + outputs = self.model.run(None, inputs) - self.generation_config.save_pretrained(save_directory) + if self.use_cache: + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 for the self-attention) + past_key_values = tuple( + torch.from_numpy(outputs[self.output_names[key]]).to(self.device) + for key in self.key_value_output_names + ) + + logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device) + if "loss" in self.output_names: + loss = torch.from_numpy(outputs[self.output_names["loss"]]).to(self.device) + + if self.use_cache and self.config.model_type not in MULTI_QUERY_ATTN_MODELS: + # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and + # per decoder layer + past_key_values = tuple( + past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) + ) + + return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values) + + def prepare_past_key_values( + self, + input_ids: Union[None, torch.LongTensor, np.ndarray], + past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]], + use_torch: bool, + ): + sequence_length = input_ids.shape[1] + + constructor = torch if use_torch else np + if self.use_merged: + # Uses without/with branch of a merged decoder depending on whether real past key values are passed + use_cache_branch = constructor.full((1,), past_key_values is not None) + else: + # Uses separate decoders + use_cache_branch = None + + if use_torch and use_cache_branch is not None: + use_cache_branch = use_cache_branch.to(self.device) + + # Generate dummy past for the first forward if uses a merged decoder + if past_key_values is None: + batch_size = input_ids.shape[0] + if self.config.model_type in {"mistral", "llama"}: + num_attention_heads = self.normalized_config.num_key_value_heads + else: + num_attention_heads = self.normalized_config.num_attention_heads + embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads + dtype = constructor.float16 if self.use_fp16 else constructor.float32 + # TODO: find a way to better handle this controlflow + # "1" is the dummy sequence length + if self.config.model_type == "bloom": + shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head) + shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0) + key = constructor.zeros(shape_key, dtype=dtype) + value = constructor.zeros(shape_value, dtype=dtype) + + if use_torch: + key = key.to(self.device) + value = value.to(self.device) + + past_key_values = tuple( + key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value] + ) + elif self.config.model_type in MULTI_QUERY_ATTN_MODELS: + shape_key_and_value = (batch_size, 0, embed_size_per_head * 2) + key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype) + + if use_torch: + key_and_value = key_and_value.to(self.device) + + past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names))) + else: + shape = (batch_size, num_attention_heads, 0, embed_size_per_head) + key_or_value = constructor.zeros(shape, dtype=dtype) + + if use_torch: + key_or_value = key_or_value.to(self.device) + + past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names))) + + pkv_output_shape = {} + for name, value in zip(self.key_value_output_names, past_key_values): + shape = [*value.shape] + index = ( + 1 + if self.config.model_type in MULTI_QUERY_ATTN_MODELS + or (self.config.model_type == "bloom" and "value" in name) + else 2 + ) + + shape[index] += sequence_length + pkv_output_shape[name] = shape + + return use_cache_branch, past_key_values, pkv_output_shape @classmethod def _from_pretrained( cls, model_id: Union[str, Path], config: "PretrainedConfig", - init_cls: Type["ORTModelDecoder"], use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: Optional[str] = None, - decoder_file_name: str = ONNX_DECODER_NAME, - decoder_with_past_file_name: str = ONNX_DECODER_WITH_PAST_NAME, + file_name: Optional[str] = None, subfolder: str = "", - local_files_only: bool = False, use_cache: bool = True, + local_files_only: bool = False, use_merged: Optional[bool] = None, provider: str = "CPUExecutionProvider", session_options: Optional[onnxruntime.SessionOptions] = None, @@ -340,7 +405,7 @@ def _from_pretrained( use_io_binding: Optional[bool] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, **kwargs, - ): + ) -> "ORTModelForCausalLM": model_path = Path(model_id) # We do not implement the logic for use_cache=False, use_merged=True @@ -352,187 +417,137 @@ def _from_pretrained( ) use_merged = False - decoder_merged_path = None - # We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it), - # and use_merged = True (explicitely specified by the user) - if use_merged is not False: - try: - decoder_merged_path = ORTModelDecoder.infer_onnx_filename( + decoder_name = "decoder_file_name" if use_cache else "decoder_with_past_file_name" + decoder_file_name = kwargs.pop(decoder_name, None) + + if decoder_file_name is not None: + logger.warning(f"The `{decoder_name}` argument is deprecated, please use `file_name` instead.") + file_name = file_name or decoder_file_name + + if file_name is None: + decoder_path = None + # We use `is not False` here to include two cases: use_merged = None (in which case we auto-detect it), + # and use_merged = True (explicitely specified by the user) + if use_merged is not False: + try: + decoder_path = ORTModelForCausalLM.infer_onnx_filename( + model_id, + [DECODER_MERGED_ONNX_FILE_PATTERN], + argument_name=None, + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + use_merged = True + file_name = decoder_path.name + except FileNotFoundError as e: + if use_merged is True: + raise FileNotFoundError( + "The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()" + " but no ONNX file for a merged decoder could be found in" + f" {str(Path(model_id, subfolder))}, with the error: {e}" + ) + use_merged = False + + if use_merged is False: + pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN + # exclude decoder file for first iteration + decoder_path = ORTModelForCausalLM.infer_onnx_filename( model_id, - [DECODER_MERGED_ONNX_FILE_PATTERN], + [r"^((?!decoder).)*.onnx", pattern], argument_name=None, subfolder=subfolder, use_auth_token=use_auth_token, revision=revision, ) - use_merged = True - decoder_path = decoder_merged_path - except FileNotFoundError as e: - if use_merged is True: - raise FileNotFoundError( - "The parameter `use_merged=True` was passed to ORTModelForCausalLM.from_pretrained()" - " but no ONNX file for a merged decoder could be found in" - f" {str(Path(model_id, subfolder))}, with the error: {e}" - ) - use_merged = False + file_name = decoder_path.name - decoder_without_past_path = None - decoder_with_past_path = None - if use_merged is False: - if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision): - decoder_without_past_path = ORTModelDecoder.infer_onnx_filename( - model_id, - [DECODER_ONNX_FILE_PATTERN], - "decoder_file_name", - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, + if file_name == ONNX_DECODER_WITH_PAST_NAME and config.model_type in MODEL_TO_PATCH_FOR_PAST: + raise ValueError( + f"{ONNX_DECODER_WITH_PAST_NAME} not supported for the following architecture : {', '.join(MODEL_TO_PATCH_FOR_PAST)}. Please re-export your model or set use_cache=False." ) - else: - decoder_without_past_path = model_path / subfolder / decoder_file_name - decoder_path = decoder_without_past_path + regular_file_names = [] + for name in [ONNX_WEIGHTS_NAME, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME]: + regular_file_names += ORTModelForCausalLM._generate_regular_names_for_filename(name) - decoder_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(ONNX_DECODER_NAME) - if decoder_path.name not in decoder_regular_onnx_filenames: + if file_name not in regular_file_names: logger.warning( - f"The ONNX file {decoder_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_regular_onnx_filenames}, the " + f"The ONNX file {file_name} is not a regular name used in optimum.onnxruntime that are {regular_file_names}, the " f"{cls.__name__} might not behave as expected." ) - # If the decoder without / with past has been merged, we do not need to look for any additional file - if use_cache is True: - if not validate_file_exists( - model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision - ): - try: - decoder_with_past_path = ORTModelDecoder.infer_onnx_filename( - model_id, - [DECODER_WITH_PAST_ONNX_FILE_PATTERN], - "decoder_with_past_file_name", - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - ) - except FileNotFoundError as e: - raise FileNotFoundError( - "The parameter `use_cache=True` was passed to ORTModelForCausalLM.from_pretrained()" - " but no ONNX file using past key values could be found in" - f" {str(Path(model_id, subfolder))}, with the error: {e}" - ) - else: - decoder_with_past_path = model_path / subfolder / decoder_with_past_file_name - - decoder_with_past_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename( - ONNX_DECODER_WITH_PAST_NAME - ) - - if decoder_with_past_path.name not in decoder_with_past_regular_onnx_filenames: - logger.warning( - f"The ONNX file {decoder_with_past_path.name} is not a regular name used in optimum.onnxruntime that are {decoder_with_past_regular_onnx_filenames}, " - f"the {cls.__name__} might not behave as expected." - ) - - preprocessors = None - if model_path.is_dir(): - new_model_save_dir = model_path - preprocessors = maybe_load_preprocessors(model_id) + if config.model_type == "bloom": + init_cls = ORTBloomForCausalLM + elif config.model_type == "mpt": + init_cls = ORTMPTForCausalLM + elif config.model_type == "opt": + init_cls = ORTOPTForCausalLM else: - attribute_name_to_filename = { - "last_decoder_model_name": decoder_path.name if use_merged is False else None, - "last_decoder_with_past_model_name": decoder_with_past_path.name - if (use_cache is True and use_merged is False) - else None, - "last_decoder_merged_name": decoder_merged_path.name if use_merged is True else None, - } - paths = {} - for attr_name, filename in attribute_name_to_filename.items(): - if filename is None: - continue - model_cache_path = hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=filename, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) + init_cls = ORTModelForCausalLM - # try download external data - try: - hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=filename + "_data", - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - except EntryNotFoundError: - # model doesn't use external data - pass + model_cache_path, preprocessors = cls._cached_file( + model_path=model_path, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + file_name=file_name, + subfolder=subfolder, + local_files_only=local_files_only, + ) + new_model_save_dir = model_cache_path.parent + + # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it + # instead of the path only. + if model_save_dir is None: + model_save_dir = new_model_save_dir - paths[attr_name] = Path(model_cache_path).name - new_model_save_dir = Path(model_cache_path).parent - preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) + # Since v1.7.0 decoder with past models have fixed sequence length of 1 + # To keep these models compatible we set this dimension to dynamic + onnx_model = onnx.load(str(model_cache_path), load_external_data=False) + model_uses_external_data = check_model_uses_external_data(onnx_model) - if use_merged is True: - decoder_path = new_model_save_dir / paths["last_decoder_merged_name"] - decoder_merged_path = new_model_save_dir / paths["last_decoder_merged_name"] - else: - decoder_path = new_model_save_dir / paths["last_decoder_model_name"] - decoder_without_past_path = new_model_save_dir / paths["last_decoder_model_name"] + if model_uses_external_data: + onnx_model = onnx.load(str(model_cache_path), load_external_data=True) - if use_cache is True: - decoder_with_past_path = new_model_save_dir / paths["last_decoder_with_past_model_name"] + input_dims = { + node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] + for node in onnx_model.graph.input + } + if input_dims["input_ids"][1] == 1: + input_dims["input_ids"][1] = "sequence_length" + output_dims = { + node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim] + for node in onnx_model.graph.output + } + output_dims["logits"][1] = "sequence_length" + onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims) + + onnx.save( + onnx_model, + str(model_cache_path), + save_as_external_data=model_uses_external_data, + all_tensors_to_one_file=True, + location=model_cache_path.name + "_data", + size_threshold=0, + ) + del onnx_model - ort_inference_sessions = cls.load_model( - decoder_path=decoder_path, - decoder_with_past_path=None if use_merged is True or use_cache is False else decoder_with_past_path, + model = ORTModel.load_model( + model_cache_path, provider=provider, session_options=session_options, provider_options=provider_options, ) - if model_save_dir is None: - model_save_dir = new_model_save_dir - - generation_config = None - try: - generation_config = GenerationConfig.from_pretrained( - model_id, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - ) - except OSError: - logger.info("Generation config file not found, using a generation config created from the model config.") - - onnx_paths = [] - if use_merged is False: - onnx_paths.append(decoder_without_past_path) - if use_cache is True: - onnx_paths.append(decoder_with_past_path) - else: - onnx_paths.append(decoder_merged_path) - return init_cls( - ort_inference_sessions[0], - config, - decoder_with_past_session=ort_inference_sessions[1], - use_cache=use_cache, + model=model, + config=config, use_io_binding=use_io_binding, model_save_dir=model_save_dir, preprocessors=preprocessors, - generation_config=generation_config, - onnx_paths=onnx_paths, + use_cache=use_cache, ) @classmethod @@ -554,19 +569,18 @@ def _from_transformers( provider_options: Optional[Dict[str, Any]] = None, use_io_binding: Optional[bool] = None, task: Optional[str] = None, - ) -> "ORTModelDecoder": + ) -> "ORTModelForCausalLM": + file_name = ONNX_WEIGHTS_NAME + + if use_merged: + logger.warning("The `use_merged` argument is deprecated when the model is exported, and not used anymore.") + use_merged = False + if task is None: task = cls._auto_model_to_task(cls.auto_model_class) - if use_cache is True: - task = task + "-with-past" - - if use_cache is False and use_merged is True: - raise ValueError( - "The incompatible arguments use_cache=False, use_merged=True were passed to ORTModelForCausalLM.from_pretrained()." - " Please pass either use_cache=False, use_merged=False to disable past key value caching, or use_cache=True, use_merged=False" - " to disable the merging of the decoder not using / using past key and value." - ) + if use_cache: + task += "-with-past" save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) @@ -576,7 +590,8 @@ def _from_transformers( output=save_dir_path, task=task, do_validation=False, - no_post_process=not use_merged, + no_post_process=False, + legacy=False, subfolder=subfolder, revision=revision, cache_dir=cache_dir, @@ -599,88 +614,7 @@ def _from_transformers( provider_options=provider_options, use_io_binding=use_io_binding, model_save_dir=save_dir, - ) - - def to(self, device: Union[torch.device, str, int]): - """ - Changes the ONNX Runtime provider according to the device. - - Args: - device (`Union[torch.device, str, int]`): - Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run - the model on the associated CUDA device id. You can pass native `torch.device` or a `str` too. - - Returns: - `ORTModel`: the model placed on the requested device. - """ - device, provider_options = parse_device(device) - - if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider": - return self - - provider = get_provider_for_device(device) - validate_provider_availability(provider) # raise error if the provider is not available - self.device = device - self.decoder.session.set_providers([provider], provider_options=[provider_options]) - if self.decoder_with_past is not None: - self.decoder_with_past.session.set_providers([provider], provider_options=[provider_options]) - self.providers = self.decoder.session.get_providers() - - return self - - -@add_end_docstrings(ONNX_MODEL_END_DOCSTRING) -class ORTModelForCausalLM(ORTModelDecoder, GenerationMixin): - """ - ONNX model with a causal language modeling head for ONNX Runtime inference. This class officially supports bloom, codegen, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gptj, llama. - """ - - auto_model_class = AutoModelForCausalLM - main_input_name = "input_ids" - - @add_start_docstrings_to_model_forward( - CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length") - + TEXT_GENERATION_EXAMPLE.format( - processor_class=_TOKENIZER_FOR_DOC, - model_class="ORTModelForCausalLM", - checkpoint="optimum/gpt2", - ) - ) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - position_ids: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - **kwargs, - ) -> CausalLMOutputWithCrossAttentions: - if past_key_values is None or self.use_cache is False: - outputs = self.decoder( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - position_ids=position_ids, - labels=labels, - ) - elif self.use_merged is True: - outputs = self.decoder( - input_ids=input_ids[:, -1:], - past_key_values=past_key_values, - attention_mask=attention_mask, - position_ids=position_ids, - ) - else: - outputs = self.decoder_with_past( - input_ids=input_ids[:, -1:], - past_key_values=past_key_values, - attention_mask=attention_mask, - labels=labels, - position_ids=position_ids, - ) - - return CausalLMOutputWithCrossAttentions( - loss=outputs.get("loss", None), logits=outputs.logits, past_key_values=outputs.past_key_values + file_name=file_name, ) # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation @@ -718,24 +652,6 @@ def can_generate(self): """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate.""" return True - @classmethod - def _from_pretrained( - cls, - model_id: Union[str, Path], - config: "PretrainedConfig", - **kwargs, - ): - if config.model_type == "bloom": - init_cls = ORTBloomForCausalLM - elif config.model_type == "mpt": - init_cls = ORTMPTForCausalLM - elif config.model_type == "opt": - init_cls = ORTOPTForCausalLM - else: - init_cls = ORTModelForCausalLM - - return super()._from_pretrained(model_id, config, init_cls=init_cls, **kwargs) - class ORTBloomForCausalLM(ORTModelForCausalLM): # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 46963745da4..b58a37eb43a 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -486,55 +486,30 @@ def _from_pretrained( "not behave as expected." ) - preprocessors = None - if model_path.is_dir(): - model = ORTModel.load_model( - model_path / file_name, - provider=provider, - session_options=session_options, - provider_options=provider_options, - ) - new_model_save_dir = model_path - preprocessors = maybe_load_preprocessors(model_id) - else: - model_cache_path = hf_hub_download( - repo_id=model_id, - filename=file_name, - subfolder=subfolder, - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - - # try download external data - try: - hf_hub_download( - repo_id=model_id, - subfolder=subfolder, - filename=file_name + "_data", - use_auth_token=use_auth_token, - revision=revision, - cache_dir=cache_dir, - force_download=force_download, - local_files_only=local_files_only, - ) - except EntryNotFoundError: - # model doesn't use external data - pass - - model = ORTModel.load_model( - model_cache_path, provider=provider, session_options=session_options, provider_options=provider_options - ) - new_model_save_dir = Path(model_cache_path).parent - preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder) + model_cache_path, preprocessors = cls._cached_file( + model_path=model_path, + use_auth_token=use_auth_token, + revision=revision, + force_download=force_download, + cache_dir=cache_dir, + file_name=file_name, + subfolder=subfolder, + local_files_only=local_files_only, + ) + new_model_save_dir = model_cache_path.parent # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it # instead of the path only. if model_save_dir is None: model_save_dir = new_model_save_dir + model = ORTModel.load_model( + model_cache_path, + provider=provider, + session_options=session_options, + provider_options=provider_options, + ) + return cls( model=model, config=config, @@ -753,13 +728,20 @@ def _prepare_io_binding( name = ordered_input_names[idx] tensor = tensor.contiguous() input_name_to_shape[name] = tensor.shape + + data_ptr = tensor.data_ptr() + if "past" in name and data_ptr == 0: + # During first generation, sequence_length can be 0 when use_cache=True, which results in data_ptr to also be 0. + # To keep compatibility with IO binding, we pass the data pointer of input_ids instead. This will have no impact because past_key_values will not be used during the first generation. + data_ptr = model_inputs[0].data_ptr() + io_binding.bind_input( name, tensor.device.type, IOBindingHelper.get_device_index(self.device), name_to_np_type[name], tuple(tensor.shape), - tensor.data_ptr(), + data_ptr, ) dimensions = {} for input_ in model.get_inputs(): @@ -821,6 +803,55 @@ def raise_on_numpy_input_io_binding(self, use_torch: bool): " with model.use_io_binding = False, or pass torch.Tensor inputs instead." ) + @staticmethod + def _cached_file( + model_path: Union[Path, str], + use_auth_token: Optional[Union[bool, str]] = None, + revision: Optional[str] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + subfolder: str = "", + local_files_only: bool = False, + ): + model_path = Path(model_path) + + # locates a file in a local folder and repo, downloads and cache it if necessary. + if model_path.is_dir(): + model_cache_path = model_path / file_name + preprocessors = maybe_load_preprocessors(model_path.as_posix()) + else: + model_cache_path = hf_hub_download( + repo_id=model_path.as_posix(), + filename=file_name, + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + # try download external data + try: + hf_hub_download( + repo_id=model_path.as_posix(), + subfolder=subfolder, + filename=file_name + "_data", + use_auth_token=use_auth_token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + ) + except EntryNotFoundError: + # model doesn't use external data + pass + + model_cache_path = Path(model_cache_path) + preprocessors = maybe_load_preprocessors(model_path.as_posix(), subfolder=subfolder) + + return model_cache_path, preprocessors + FEATURE_EXTRACTION_EXAMPLE = r""" Example of feature extraction: diff --git a/optimum/onnxruntime/optimization.py b/optimum/onnxruntime/optimization.py index 2db9f753c34..9e62a3f324c 100644 --- a/optimum/onnxruntime/optimization.py +++ b/optimum/onnxruntime/optimization.py @@ -97,17 +97,11 @@ def from_pretrained( # Add the decoder with past key/values if present if model_or_path.use_cache: onnx_model_path.append(model_or_path.decoder_with_past_model_path) - elif isinstance(model_or_path, ORTModelForCausalLM): - if model_or_path.use_merged is True: - raise NotImplementedError( - "ORTOptimizer does not support ORTModelForCausalLM models that use a single ONNX for both the without/with past cases." - " Please pass an ORTModelForCausalLM that uses a separate ONNX for each without/with past cases. This can be done" - " by using `ORTModelForCausalLM.from_pretrained(..., export=True, use_merged=False)`, or by" - " using the option `--no-post-process` in the optimum-cli ONNX export tool." - ) - onnx_model_path.append(model_or_path.decoder_model_path) - if model_or_path.use_cache: - onnx_model_path.append(model_or_path.decoder_with_past_model_path) + elif isinstance(model_or_path, ORTModelForCausalLM) and model_or_path.use_merged: + raise NotImplementedError( + "ORTOptimizer does not support ORTModelForCausalLM models when without/with past models are merged. " + "Please re-export your model. This can be done by using the optimum-cli ONNX export tool or `ORTModelForCausalLM.from_pretrained(..., export=True, use_merged=False)`." + ) else: onnx_model_path.append(model_or_path.model_path) config = model_or_path.config diff --git a/optimum/onnxruntime/quantization.py b/optimum/onnxruntime/quantization.py index 1c13bfb465f..d56e301c3cf 100644 --- a/optimum/onnxruntime/quantization.py +++ b/optimum/onnxruntime/quantization.py @@ -33,7 +33,6 @@ from ..utils.save_utils import maybe_save_preprocessors from . import ORTQuantizableOperator from .configuration import CalibrationConfig, ORTConfig, QuantizationConfig -from .modeling_decoder import ORTModelForCausalLM from .modeling_ort import ORTModel from .modeling_seq2seq import ORTModelForConditionalGeneration from .preprocessors import QuantizationPreprocessor @@ -136,13 +135,6 @@ def from_pretrained( path = None if isinstance(model_or_path, ORTModelForConditionalGeneration): raise NotImplementedError(ort_quantizer_error_message) - elif isinstance(model_or_path, ORTModelForCausalLM): - if model_or_path.use_cache is False: - path = Path(model_or_path.decoder_model_path) - elif model_or_path.use_cache is True and model_or_path.use_merged is False: - raise NotImplementedError(ort_quantizer_error_message) - else: - path = Path(model_or_path.decoder_model_path) elif isinstance(model_or_path, Path) and file_name is None: onnx_files = list(model_or_path.glob("*.onnx")) if len(onnx_files) == 0: diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index 89f2f5598a6..67e12861eb5 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -13,6 +13,22 @@ # limitations under the License. import functools +from typing import Tuple + +import torch + + +MODEL_TO_PATCH_FOR_PAST = { + "bart", + "blenderbot", + "blenderbot-small", + "bloom", + "llama", + "mistral", + "mpt", + "opt", + "pegasus", +} def recurse_getattr(obj, attr: str): @@ -39,3 +55,126 @@ def recurse_setattr(module, name, value): else: name, rest = name.split(".", 1) recurse_setattr(getattr(module, name), rest, value) + + +# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, + device: torch.device, + past_key_values_length: int, + dtype: torch.dtype = torch.bool, +) -> torch.BoolTensor: + """ + Make causal mask used for bi-directional self-attention. + """ + batch_size, target_length = input_ids_shape + mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device) + seq_ids = torch.arange(target_length, device=device) + + mask[:, past_key_values_length:] = ( + (seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min + if torch.is_floating_point(mask) + else seq_ids[:, None] < seq_ids[None, :] + ) + + return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) + + +# NOTE: For MODEL_TO_PATCH_FOR_PAST architectures, when exporting the model with an input of sequence length of 1, the attention masks will be generated incorrectly for other sequence length +# https://github.com/huggingface/transformers/blob/0ee45906845c8d58b9bd2df5acd90e09b00047ff/src/transformers/models/bloom/modeling_bloom.py#L654 +# The method taking care of the decoder mask generation of the models from these architectures must be patched during export for sequence length of 1. + + +# Modified from transformers.models.bloom.modeling_bloom._prepare_attn_mask +def _prepare_attn_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + past_key_values_length: int, +) -> torch.BoolTensor: + from transformers.models.bloom.modeling_bloom import _expand_mask + + # create causal mask + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + combined_attention_mask = _make_causal_mask( + input_shape, device=device, past_key_values_length=past_key_values_length + ) + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] + expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + +# Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask +def _prepare_decoder_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + inputs_embeds: torch.Tensor, + past_key_values_length: int, +): + from transformers.models.llama.modeling_llama import _expand_mask + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + combined_attention_mask = _make_causal_mask( + input_shape, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + dtype=inputs_embeds.dtype, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +# Modified from transformers.models.mistral.modeling_mistral._prepare_decoder_sliding_window_attention_mask +def _prepare_decoder_sliding_window_attention_mask( + self, + attention_mask: torch.Tensor, + input_shape: Tuple[int, int], + inputs_embeds: torch.Tensor, + past_key_values_length: int, + sliding_window: int, +): + from transformers.models.mistral.modeling_mistral import _expand_mask, _make_sliding_window_causal_mask + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + + combined_attention_mask = _make_sliding_window_causal_mask( + input_shape, + device=inputs_embeds.device, + dtype=inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + sliding_window=sliding_window, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask diff --git a/setup.py b/setup.py index 7a7f4546844..f654e3a71bc 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,15 @@ ], "exporters": ["onnx", "onnxruntime", "timm"], "exporters-gpu": ["onnx", "onnxruntime-gpu", "timm"], - "exporters-tf": ["tensorflow>=2.4,<=2.12.1", "tf2onnx", "onnx", "onnxruntime", "timm", "h5py", "numpy<1.24.0"], + "exporters-tf": [ + "tensorflow>=2.4,<=2.12.1", + "tf2onnx", + "onnx", + "onnxruntime", + "timm", + "h5py", + "numpy<1.24.0", + ], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.11.0", "openvino": "optimum-intel[openvino]>=1.11.0", diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index b1cdedbea84..efdbaba4235 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -19,6 +19,7 @@ from tempfile import TemporaryDirectory from typing import Dict, Optional +import onnx import pytest from parameterized import parameterized from transformers import AutoModelForSequenceClassification, AutoTokenizer, is_torch_available @@ -26,7 +27,12 @@ from optimum.exporters.error_utils import MinimumVersionError from optimum.exporters.onnx.__main__ import main_export -from optimum.onnxruntime import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME +from optimum.onnxruntime import ( + ONNX_DECODER_MERGED_NAME, + ONNX_DECODER_NAME, + ONNX_DECODER_WITH_PAST_NAME, + ONNX_ENCODER_NAME, +) from optimum.utils.testing_utils import require_diffusers, require_timm @@ -413,6 +419,21 @@ def test_stable_diffusion(self): check=True, ) + def test_legacy(self): + with TemporaryDirectory() as tmpdirname: + subprocess.run( + f"python3 -m optimum.exporters.onnx --model hf-internal-testing/tiny-random-gpt2 --task text-generation-with-past --legacy {tmpdirname}", + shell=True, + capture_output=True, + ) + folder_contents = os.listdir(tmpdirname) + self.assertIn(ONNX_DECODER_NAME, folder_contents) + self.assertIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents) + self.assertIn(ONNX_DECODER_MERGED_NAME, folder_contents) + + model = onnx.load(Path(tmpdirname) / ONNX_DECODER_MERGED_NAME) + self.assertNotIn("position_ids", {node.name for node in model.graph.input}) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) @require_vision @require_torch_gpu diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 10eaeddd13c..11e6a53da36 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -14,6 +14,7 @@ # limitations under the License. import gc import os +from functools import partial from pathlib import Path from tempfile import TemporaryDirectory from typing import Dict @@ -529,8 +530,8 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name} -def fn_get_submodels_custom(model): - return {"decoder_model": model, "decoder_with_past_model": model} +def fn_get_submodels_custom(model, legacy=False): + return {"decoder_model": model, "decoder_with_past_model": model} if legacy else {"model": model} class OnnxCustomExport(TestCase): @@ -572,7 +573,6 @@ def test_custom_export_official_model(self): def test_custom_export_trust_remote(self, fn_get_submodels): model_id = "fxmarty/tiny-mpt-random-remote-code" config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) - onnx_config = CustomMPTOnnxConfig( config=config, task="text-generation", @@ -581,22 +581,29 @@ def test_custom_export_trust_remote(self, fn_get_submodels): ) onnx_config_with_past = CustomMPTOnnxConfig(config, task="text-generation", use_past=True) - custom_onnx_configs = { - "decoder_model": onnx_config, - "decoder_with_past_model": onnx_config_with_past, - } + for legacy in (True, False): + if legacy: + custom_onnx_configs = { + "decoder_model": onnx_config, + "decoder_with_past_model": onnx_config_with_past, + } + else: + custom_onnx_configs = { + "model": onnx_config_with_past, + } - with TemporaryDirectory() as tmpdirname: - main_export( - model_id, - output=tmpdirname, - task="text-generation-with-past", - trust_remote_code=True, - custom_onnx_configs=custom_onnx_configs, - no_post_process=True, - fn_get_submodels=fn_get_submodels, - opset=14, - ) + with TemporaryDirectory() as tmpdirname: + main_export( + model_id, + output=tmpdirname, + task="text-generation-with-past", + trust_remote_code=True, + custom_onnx_configs=custom_onnx_configs, + no_post_process=True, + fn_get_submodels=partial(fn_get_submodels, legacy=legacy) if fn_get_submodels else None, + legacy=legacy, + opset=14, + ) def test_custom_export_trust_remote_error(self): model_id = "mohitsha/tiny-ernie-random-remote-code" diff --git a/tests/onnx/test_onnx_graph_transformations.py b/tests/onnx/test_onnx_graph_transformations.py index bed539eaccb..c06ac5af971 100644 --- a/tests/onnx/test_onnx_graph_transformations.py +++ b/tests/onnx/test_onnx_graph_transformations.py @@ -85,6 +85,7 @@ def test_merge_decoders(self, *args): tmpdir, task=task, no_post_process=True, + legacy=True, ) decoder = onnx.load(os.path.join(tmpdir, "decoder_model.onnx")) diff --git a/tests/onnxruntime/nightly_test_trainer.py b/tests/onnxruntime/nightly_test_trainer.py index 38bdfd07973..2eb3ca433f7 100644 --- a/tests/onnxruntime/nightly_test_trainer.py +++ b/tests/onnxruntime/nightly_test_trainer.py @@ -40,11 +40,7 @@ default_data_collator, is_torch_available, ) -from transformers.testing_utils import ( - require_deepspeed, - require_torch, - slow, -) +from transformers.testing_utils import require_deepspeed, require_torch, slow from transformers.training_args import OptimizerNames diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index e6868c2fa7d..b7695cbd651 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -59,7 +59,7 @@ ) from transformers.modeling_utils import no_init_weights from transformers.onnx.utils import get_preprocessor -from transformers.testing_utils import get_gpu_count, require_torch_gpu +from transformers.testing_utils import get_gpu_count, require_torch_gpu, slow from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin from optimum.exporters import TasksManager @@ -138,12 +138,12 @@ def __init__(self, *args, **kwargs): def test_load_model_from_local_path(self): model = ORTModel.from_pretrained(self.LOCAL_MODEL_PATH) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) def test_load_model_from_hub(self): model = ORTModel.from_pretrained(self.ONNX_MODEL_ID) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) def test_load_model_from_hub_subfolder(self): @@ -151,11 +151,11 @@ def test_load_model_from_hub_subfolder(self): model = ORTModelForSequenceClassification.from_pretrained( "fxmarty/tiny-bert-sst2-distilled-subfolder", subfolder="my_subfolder", export=True ) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) model = ORTModel.from_pretrained("fxmarty/tiny-bert-sst2-distilled-onnx-subfolder", subfolder="my_subfolder") - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) def test_load_seq2seq_model_from_hub_subfolder(self): @@ -178,7 +178,7 @@ def test_load_model_from_cache(self): model = ORTModel.from_pretrained(self.TINY_ONNX_MODEL_ID, local_files_only=True) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) def test_load_model_from_empty_cache(self): @@ -768,7 +768,7 @@ def test_stable_diffusion_model_on_gpu_str(self): @require_hf_token def test_load_model_from_hub_private(self): model = ORTModel.from_pretrained(self.ONNX_MODEL_ID, use_auth_token=os.environ.get("HF_AUTH_TOKEN", None)) - self.assertIsInstance(model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) self.assertIsInstance(model.config, PretrainedConfig) def test_save_model(self): @@ -832,11 +832,12 @@ def test_save_load_ort_model_with_external_data(self): os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") @parameterized.expand([(False,), (True,)]) + @pytest.mark.run_slow + @slow def test_save_load_decoder_model_with_external_data(self, use_cache: bool): with tempfile.TemporaryDirectory() as tmpdirname: - os.environ["FORCE_ONNX_EXTERNAL_DATA"] = "1" # force exporting small model with external data model = ORTModelForCausalLM.from_pretrained( - MODEL_NAMES["gpt2"], + "gpt2-large", use_cache=use_cache, export=True, use_merged=False, @@ -846,18 +847,14 @@ def test_save_load_decoder_model_with_external_data(self, use_cache: bool): # verify external data is exported folder_contents = os.listdir(tmpdirname) - self.assertTrue(ONNX_DECODER_NAME in folder_contents) - self.assertTrue(ONNX_DECODER_NAME + "_data" in folder_contents) - - if use_cache: - self.assertTrue(ONNX_DECODER_WITH_PAST_NAME in folder_contents) - self.assertTrue(ONNX_DECODER_WITH_PAST_NAME + "_data" in folder_contents) + self.assertTrue(ONNX_WEIGHTS_NAME in folder_contents) + self.assertTrue(ONNX_WEIGHTS_NAME + "_data" in folder_contents) + self.assertFalse(use_cache ^ model.use_cache) # verify loading from local folder works model = ORTModelForCausalLM.from_pretrained( tmpdirname, use_cache=use_cache, export=False, use_io_binding=False ) - os.environ.pop("FORCE_ONNX_EXTERNAL_DATA") @parameterized.expand([(False,), (True,)]) def test_save_load_seq2seq_model_with_external_data(self, use_cache: bool): @@ -1103,7 +1100,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForQuestionAnswering.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1270,7 +1267,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForMaskedLM.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1432,7 +1429,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForSequenceClassification.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1605,7 +1602,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForTokenClassification.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1728,7 +1725,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForFeatureExtraction.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1873,7 +1870,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForMultipleChoice.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -1962,7 +1959,6 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): FULL_GRID = { "model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [False, True], - "use_merged": [False, True], } ORTMODEL_CLASS = ORTModelForCausalLM @@ -1971,27 +1967,37 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.1 - def test_inference_old_onnx_model(self): - model = ORTModelForCausalLM.from_pretrained("optimum/gpt2") - - tokenizer = get_preprocessor("optimum/gpt2") + @parameterized.expand([(False,), (True,)]) + def test_inference_old_onnx_model(self, use_cache): + model_id = "optimum/gpt2" + model = AutoModelForCausalLM.from_pretrained("gpt2") + tokenizer = get_preprocessor(model_id) text = "This is a sample output" tokens = tokenizer(text, return_tensors="pt") + onnx_model = ORTModelForCausalLM.from_pretrained(model_id, use_cache=use_cache, use_io_binding=use_cache) - model.generate(**tokens) + self.assertEqual(onnx_model.use_cache, use_cache) + self.assertEqual(onnx_model.model_path.name, ONNX_DECODER_WITH_PAST_NAME if use_cache else ONNX_DECODER_NAME) + outputs_onnx = onnx_model.generate( + **tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30 + ) + outputs = model.generate(**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30) + self.assertTrue(torch.allclose(outputs_onnx, outputs)) def test_load_model_from_hub_onnx(self): model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-without-merge") self.assertFalse(model.use_merged) self.assertTrue(model.use_cache) - self.assertTrue(model.decoder_with_past is not None) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) + self.assertEqual(model.onnx_paths[0].name, ONNX_DECODER_WITH_PAST_NAME) model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-with-merge") self.assertTrue(model.use_merged) self.assertTrue(model.use_cache) - self.assertTrue(model.decoder_with_past is None) + self.assertIsInstance(model.model, onnxruntime.InferenceSession) + self.assertEqual(model.onnx_paths[0].name, ONNX_DECODER_MERGED_NAME) def test_load_vanilla_transformers_which_is_not_supported(self): with self.assertRaises(Exception) as context: @@ -1999,24 +2005,6 @@ def test_load_vanilla_transformers_which_is_not_supported(self): self.assertIn("Unrecognized configuration class", str(context.exception)) - @parameterized.expand(SUPPORTED_ARCHITECTURES) - def test_merge_from_transformers_and_save(self, model_arch): - if "text-generation-with-past" not in TasksManager.get_supported_tasks_for_model_type( - model_arch.replace("_", "-"), exporter="onnx" - ): - self.skipTest("Unsupported -with-past export case") - - model_id = MODEL_NAMES[model_arch] - model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_merged=True) - with tempfile.TemporaryDirectory() as tmpdir: - model.save_pretrained(tmpdir) - save_path = os.path.join(tmpdir, ONNX_DECODER_MERGED_NAME) - self.assertTrue(has_onnx_input(save_path, "use_cache_branch")) - - folder_contents = os.listdir(tmpdir) - self.assertTrue(ONNX_DECODER_NAME not in folder_contents) - self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) - @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_merge_from_onnx_and_save(self, model_arch): model_id = MODEL_NAMES[model_arch] @@ -2026,26 +2014,23 @@ def test_merge_from_onnx_and_save(self, model_arch): self.skipTest("Unsupported export case") with tempfile.TemporaryDirectory() as tmpdir: - main_export(model_id, tmpdir, task=task) + main_export(model_id, tmpdir, task=task, legacy=True) model = ORTModelForCausalLM.from_pretrained(tmpdir) self.assertTrue(model.use_merged) - self.assertTrue(model.decoder_with_past is None) - + self.assertIsInstance(model.model, onnxruntime.InferenceSession) model.save_pretrained(tmpdir + "_save") save_path = os.path.join(tmpdir + "_save", ONNX_DECODER_MERGED_NAME) self.assertTrue(has_onnx_input(save_path, "use_cache_branch")) folder_contents = os.listdir(tmpdir + "_save") - self.assertTrue(ONNX_DECODER_NAME not in folder_contents) - self.assertTrue(ONNX_DECODER_WITH_PAST_NAME not in folder_contents) + self.assertNotIn(ONNX_DECODER_NAME, folder_contents) + self.assertNotIn(ONNX_DECODER_WITH_PAST_NAME, folder_contents) + self.assertNotIn(ONNX_WEIGHTS_NAME, folder_contents) @parameterized.expand(grid_parameters(FULL_GRID)) - def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): - if use_cache is False and use_merged is True: - self.skipTest("use_cache=False, use_merged=True are uncompatible") - + def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool): use_io_binding = None if use_cache is False: use_io_binding = False @@ -2054,7 +2039,6 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach "test_name": test_name, "model_arch": model_arch, "use_cache": use_cache, - "use_merged": use_merged, } self._setup(model_args) @@ -2064,21 +2048,11 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach use_cache=use_cache, use_io_binding=use_io_binding, ) - if use_merged is False: - model_path = Path(self.onnx_model_dirs[test_name], ONNX_DECODER_NAME) - self.assertFalse(has_onnx_input(model_path, "use_cache_branch")) - self.assertEqual(onnx_model.use_merged, False) - else: - model_path = Path(self.onnx_model_dirs[test_name], ONNX_DECODER_MERGED_NAME) - self.assertTrue(has_onnx_input(model_path, "use_cache_branch")) - self.assertEqual(onnx_model.use_merged, True) - - self.assertIsInstance(onnx_model.decoder, ORTDecoder) - if onnx_model.use_cache is True and onnx_model.use_merged is False: - self.assertIsInstance(onnx_model.decoder_with_past, ORTDecoder) - if onnx_model.use_cache is True and onnx_model.use_merged is True: - self.assertTrue(onnx_model.decoder_with_past is None) + model_path = Path(self.onnx_model_dirs[test_name], ONNX_WEIGHTS_NAME) + self.assertFalse(has_onnx_input(model_path, "use_cache_branch")) + self.assertFalse(onnx_model.use_merged) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -2122,10 +2096,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach gc.collect() @parameterized.expand(grid_parameters(FULL_GRID)) - def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool): - if use_cache is False and use_merged is True: - self.skipTest("use_cache=False, use_merged=True are uncompatible") - + def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bool): use_io_binding = None if use_cache is False: use_io_binding = False @@ -2134,7 +2105,6 @@ def test_pipeline_ort_model(self, test_name: str, model_arch: str, use_cache: bo "test_name": test_name, "model_arch": model_arch, "use_cache": use_cache, - "use_merged": use_merged, } self._setup(model_args) @@ -2284,18 +2254,10 @@ def test_compare_with_and_without_past_key_values(self, model_arch): @parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]})) def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, model_arch: str, use_cache: bool): - model_args = { - "test_name": test_name + "_True", - "model_arch": model_arch, - "use_cache": use_cache, - "use_merged": True, - } - self._setup(model_args) model_args = { "test_name": test_name + "_False", "model_arch": model_arch, "use_cache": use_cache, - "use_merged": False, } self._setup(model_args) @@ -2303,20 +2265,29 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode tokenizer = get_preprocessor(model_id) text = "My Name is Philipp and i live" tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None) - model_not_merged_dir = self.onnx_model_dirs[test_name + "_False"] - model_merged_dir = self.onnx_model_dirs[test_name + "_True"] - model_not_merged = ORTModelForCausalLM.from_pretrained(model_not_merged_dir) - not_merged_onnx_path = Path(model_not_merged_dir, ONNX_DECODER_NAME) + not_merged_onnx_path = Path(model_not_merged_dir, ONNX_WEIGHTS_NAME) self.assertFalse(has_onnx_input(not_merged_onnx_path, "use_cache_branch")) - self.assertEqual(model_not_merged.use_merged, False) + self.assertFalse(model_not_merged.use_merged) + + model_merged_dir = Path(model_not_merged_dir) / "merged" + task = model_not_merged.export_feature + if use_cache: + task += "-with-past" + + main_export( + model_id, + output=model_merged_dir, + task=task, + no_post_process=False, + legacy=True, + ) model_merged = ORTModelForCausalLM.from_pretrained(model_merged_dir) merged_onnx_path = Path(model_merged_dir, ONNX_DECODER_MERGED_NAME) self.assertTrue(has_onnx_input(merged_onnx_path, "use_cache_branch")) - self.assertEqual(model_merged.decoder_with_past, None) - self.assertEqual(model_merged.use_merged, True) + self.assertTrue(model_merged.use_merged) outputs_model_not_merged = model_not_merged.generate(**tokens) outputs_model_merged = model_merged.generate(**tokens) @@ -2435,7 +2406,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] if model_arch in MODEL_NAMES else self.ARCH_MODEL_MAP[model_arch] onnx_model = ORTModelForImageClassification.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -2575,7 +2546,7 @@ def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] onnx_model = ORTModelForSemanticSegmentation.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -2730,7 +2701,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForAudioClassification.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -2882,7 +2853,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForCTC.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -2941,7 +2912,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForAudioXVector.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) @@ -3033,7 +3004,7 @@ def test_compare_to_transformers(self, model_arch): model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] onnx_model = ORTModelForAudioFrameClassification.from_pretrained(self.onnx_model_dirs[model_arch]) - self.assertIsInstance(onnx_model.model, onnxruntime.capi.onnxruntime_inference_collection.InferenceSession) + self.assertIsInstance(onnx_model.model, onnxruntime.InferenceSession) self.assertIsInstance(onnx_model.config, PretrainedConfig) set_seed(SEED) diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py index 7e3e670edd0..cd2127ac3be 100644 --- a/tests/onnxruntime/test_optimization.py +++ b/tests/onnxruntime/test_optimization.py @@ -518,15 +518,7 @@ def _test_optimization_levels( ort_model = ORTModelForCausalLM.from_pretrained( self.onnx_model_dirs[export_name], use_cache=use_cache, provider=provider, use_io_binding=use_io_binding ) - - if use_merged: - with self.assertRaises(NotImplementedError) as cm: - optimizer = ORTOptimizer.from_pretrained(ort_model) - - self.assertTrue("ORTModelForCausalLM models that use a single ONNX" in str(cm.exception)) - self.skipTest("Unsupported optimization case") - else: - optimizer = ORTOptimizer.from_pretrained(ort_model) + optimizer = ORTOptimizer.from_pretrained(ort_model) if provider == "CUDAExecutionProvider": for_gpu = True @@ -541,7 +533,6 @@ def _test_optimization_levels( with tempfile.TemporaryDirectory(suffix="_optimized") as tmp_dir: optimizer.optimize(save_dir=tmp_dir, optimization_config=optimization_config) - optimized_model = ORTModelForCausalLM.from_pretrained( tmp_dir, use_cache=use_cache, provider=provider, use_io_binding=use_io_binding ) @@ -594,3 +585,15 @@ def test_optimization_levels_gpu( provider="CUDAExecutionProvider", use_io_binding=use_io_binding, ) + + def test_merged_optimization(self): + ort_model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-with-merge") + self.assertTrue(ort_model.use_cache) + + with self.assertRaises(NotImplementedError) as cm: + ORTOptimizer.from_pretrained(ort_model) + + self.assertTrue( + "ORTOptimizer does not support ORTModelForCausalLM models when without/with past models are merged" + in str(cm.exception) + ) diff --git a/tests/onnxruntime/test_quantization.py b/tests/onnxruntime/test_quantization.py index aff1b51b534..4062c556ea9 100644 --- a/tests/onnxruntime/test_quantization.py +++ b/tests/onnxruntime/test_quantization.py @@ -35,6 +35,7 @@ ORTQuantizer, QuantizationConfig, ) +from optimum.utils.testing_utils import grid_parameters class ORTQuantizerTest(unittest.TestCase): @@ -78,6 +79,10 @@ class ORTDynamicQuantizationTest(unittest.TestCase): (ORTModelForSequenceClassification, "hf-internal-testing/tiny-random-bart", 32), ) + SUPPORTED_DECODER_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( + (ORTModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 22), + ) + @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS) def test_dynamic_quantization(self, model_cls, model_name, expected_quantized_matmuls): qconfig = QuantizationConfig( @@ -96,11 +101,7 @@ def test_dynamic_quantization(self, model_cls, model_name, expected_quantized_ma model.save_pretrained(tmp_dir) quantizer = ORTQuantizer.from_pretrained(model) - quantizer.quantize( - save_dir=output_dir, - quantization_config=qconfig, - ) - + quantizer.quantize(save_dir=output_dir, quantization_config=qconfig) expected_ort_config = ORTConfig(quantization=qconfig) ort_config = ORTConfig.from_pretrained(tmp_dir) # Verify the ORTConfig was correctly created and saved @@ -119,19 +120,12 @@ def test_dynamic_quantization_subgraphs(self): qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=True) tmp_dir = tempfile.mkdtemp() output_dir = Path(tmp_dir) - model = ORTModelForCausalLM.from_pretrained( - "hf-internal-testing/tiny-random-gpt2", export=True, use_merged=True - ) - + model = ORTModelForCausalLM.from_pretrained("fxmarty/onnx-tiny-random-gpt2-with-merge", use_merged=True) self.assertTrue(model.use_merged) model.save_pretrained(tmp_dir) quantizer = ORTQuantizer.from_pretrained(model) - quantizer.quantize( - save_dir=output_dir, - quantization_config=qconfig, - ) - + quantizer.quantize(save_dir=output_dir, quantization_config=qconfig) expected_ort_config = ORTConfig(quantization=qconfig) ort_config = ORTConfig.from_pretrained(tmp_dir) # Verify the ORTConfig was correctly created and saved @@ -146,6 +140,34 @@ def test_dynamic_quantization_subgraphs(self): self.assertTrue(num_quantized_matmul > 0) gc.collect() + @parameterized.expand( + grid_parameters( + {"model_arch": SUPPORTED_DECODER_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS, "use_cache": [True, False]} + ) + ) + def test_decoder_quantization_with_and_without_cache(self, test_name, model_info, use_cache): + model_cls, model_name, expected_quantized_matmuls = model_info + qconfig = AutoQuantizationConfig.avx512(is_static=False, per_channel=True) + model = model_cls.from_pretrained(model_name, export=True, use_cache=use_cache, use_io_binding=use_cache) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + output_dir = Path(tmp_dir) + quantizer = ORTQuantizer.from_pretrained(model) + quantizer.quantize(save_dir=output_dir, quantization_config=qconfig) + expected_ort_config = ORTConfig(quantization=qconfig) + ort_config = ORTConfig.from_pretrained(tmp_dir) + + # Verify the ORTConfig was correctly created and saved + self.assertEqual(ort_config.to_dict(), expected_ort_config.to_dict()) + quantized_model = onnx_load(output_dir.joinpath("model_quantized.onnx")) + num_quantized_matmul = 0 + for initializer in quantized_model.graph.initializer: + if "weight" in initializer.name and "quantized" in initializer.name: + num_quantized_matmul += 1 + self.assertEqual(expected_quantized_matmuls, num_quantized_matmul) + gc.collect() + class ORTStaticQuantizationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS = ( @@ -184,10 +206,7 @@ def preprocess_function(examples, tokenizer): dataset_split="train", ) calibration_config = AutoCalibrationConfig.minmax(calibration_dataset) - ranges = quantizer.fit( - dataset=calibration_dataset, - calibration_config=calibration_config, - ) + ranges = quantizer.fit(dataset=calibration_dataset, calibration_config=calibration_config) quantizer.quantize( save_dir=output_dir, calibration_tensors_range=ranges, diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 09603c5e1e8..949cfa242e3 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -40,7 +40,7 @@ "clip": "hf-internal-testing/tiny-random-CLIPModel", "convbert": "hf-internal-testing/tiny-random-ConvBertModel", "convnext": "hf-internal-testing/tiny-random-convnext", - "codegen": "hf-internal-testing/tiny-random-CodeGenModel", + "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel", "data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel", "data2vec_audio": "hf-internal-testing/tiny-random-Data2VecAudioModel", @@ -62,7 +62,7 @@ "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel", "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM", - "gptj": "hf-internal-testing/tiny-random-GPTJModel", + "gptj": "hf-internal-testing/tiny-random-GPTJForCausalLM", "groupvit": "hf-internal-testing/tiny-random-groupvit", "hubert": "hf-internal-testing/tiny-random-HubertModel", "ibert": "hf-internal-testing/tiny-random-IBertModel",