From 1c990c622c796a82ada72866efacb44153376cc9 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Thu, 19 Oct 2023 14:06:22 +0200 Subject: [PATCH 01/10] Remove _prepare_decoder_attention_mask patching --- optimum/exporters/onnx/model_configs.py | 24 +-------- optimum/exporters/onnx/model_patcher.py | 66 ------------------------ optimum/exporters/onnx/utils.py | 2 +- optimum/utils/modeling_utils.py | 68 ------------------------- 4 files changed, 2 insertions(+), 158 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index e1461c2a0c4..f3846eb04fb 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -61,12 +61,8 @@ VisionOnnxConfig, ) from .model_patcher import ( - BartModelPatcher, BloomModelPatcher, FalconModelPatcher, - LlamaModelPatcher, - MistralModelPatcher, - OPTModelPatcher, SAMModelPatcher, SpeechT5ModelPatcher, VisionEncoderDecoderPatcher, @@ -230,11 +226,6 @@ class OPTOnnxConfig(TextDecoderOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return OPTModelPatcher(self, model, model_kwargs=model_kwargs) - class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) @@ -242,11 +233,6 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return LlamaModelPatcher(self, model, model_kwargs=model_kwargs) - class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): # The ONNX export of this architecture needs the Trilu operator support, 
available since opset 14 @@ -257,11 +243,6 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return MistralModelPatcher(self, model, model_kwargs=model_kwargs) - class MPTOnnxConfig(TextDecoderOnnxConfig): # MPT does not require position_ids input. @@ -655,10 +636,7 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): class BartOnnxConfig(M2M100OnnxConfig): - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return BartModelPatcher(self, model, model_kwargs=model_kwargs) + pass class MBartOnnxConfig(BartOnnxConfig): diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 4a5f4d1ace4..938ed042a3c 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -27,8 +27,6 @@ from ...utils.modeling_utils import ( _falcon_prepare_attn_mask, _prepare_attn_mask, - _prepare_decoder_attention_mask, - _prepare_decoder_sliding_window_attention_mask, ) @@ -799,67 +797,3 @@ def __init__( self._patch_func = _prepare_attn_mask self._orig_func_name = "_prepare_attn_mask" self._orig_func = self._model_to_patch._prepare_attn_mask - - -class OPTModelPatcher(CausalAttentionMaskModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - - if self.patch: - self._model_to_patch = model.model.decoder - self._patch_func = _prepare_decoder_attention_mask - 
self._orig_func_name = "_prepare_decoder_attention_mask" - self._orig_func = self._model_to_patch._prepare_decoder_attention_mask - - -class LlamaModelPatcher(CausalAttentionMaskModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - - if self.patch: - self._model_to_patch = model.model - self._patch_func = _prepare_decoder_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" - self._orig_func = self._model_to_patch._prepare_decoder_attention_mask - - -class MistralModelPatcher(CausalAttentionMaskModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - - if self.patch: - self._model_to_patch = model.model - self._patch_func = _prepare_decoder_sliding_window_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" - self._orig_func = self._model_to_patch._prepare_decoder_attention_mask - - -class BartModelPatcher(CausalAttentionMaskModelPatcher, Seq2SeqModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - - if self.patch: - self._model_to_patch = model.model.decoder - self._patch_func = _prepare_decoder_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" - self._orig_func = self._model_to_patch._prepare_decoder_attention_mask diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index ef6206e8d06..2e49020373f 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -30,7 +30,7 @@ logging, ) from ...utils.import_utils import _diffusers_version -from 
...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask # noqa: F401 +from ...utils.modeling_utils import _prepare_attn_mask # noqa: F401 from ..tasks import TasksManager from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index 336ad31e5a7..019f93c6f18 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -112,74 +112,6 @@ def _prepare_attn_mask( return combined_attention_mask -# Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask -def _prepare_decoder_attention_mask( - self, - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - inputs_embeds: torch.Tensor, - past_key_values_length: int, -): - from transformers.models.llama.modeling_llama import _expand_mask - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - - combined_attention_mask = _make_causal_mask( - input_shape, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - dtype=inputs_embeds.dtype, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - -# Modified from transformers.models.mistral.modeling_mistral._prepare_decoder_sliding_window_attention_mask -def _prepare_decoder_sliding_window_attention_mask( - self, - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: int, -): - from transformers.models.mistral.modeling_mistral import _expand_mask, 
_make_sliding_window_causal_mask - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - - combined_attention_mask = _make_sliding_window_causal_mask( - input_shape, - device=inputs_embeds.device, - dtype=inputs_embeds.dtype, - past_key_values_length=past_key_values_length, - sliding_window=sliding_window, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def _falcon_prepare_attn_mask( attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int ) -> torch.BoolTensor: From c26518b2c944d79e66420a928c93e6da725f8d86 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Sat, 21 Oct 2023 13:55:44 +0200 Subject: [PATCH 02/10] Add specific warning for exports with sequence_length set to 1 --- optimum/exporters/onnx/__main__.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index df5c2498eff..6b7a972a6c1 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -317,15 +317,6 @@ def main_export( model_name_or_path, subfolder=subfolder, library_name=library_name ) - # get the shapes to be used to generate dummy inputs - input_shapes = {} - for input_name in DEFAULT_DUMMY_SHAPES.keys(): - input_shapes[input_name] = ( - kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] - ) - - torch_dtype = None if fp16 is False else torch.float16 - if task == "auto": try: task = TasksManager.infer_task_from_model(model_name_or_path) @@ -338,6 +329,21 @@ def main_export( 
f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" ) + # get the shapes to be used to generate dummy inputs + input_shapes = {} + for input_name in DEFAULT_DUMMY_SHAPES.keys(): + input_shapes[input_name] = ( + kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] + ) + if ( + input_name == "sequence_length" + and kwargs_shapes.get(input_name) == 1 + and task.startswith("text-generation") + ): + logger.warning("Exporting with a sequence length of 1 for text generation models is not supported and can yield unexpected results.") + + torch_dtype = None if fp16 is False else torch.float16 + custom_architecture = False if library_name == "transformers": config = AutoConfig.from_pretrained( From 30a922c450a452d9a1f375451c714a20aaa17aa7 Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Thu, 26 Oct 2023 16:32:38 +0000 Subject: [PATCH 03/10] Style --- optimum/exporters/onnx/__main__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 6b7a972a6c1..da018490d69 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -340,7 +340,9 @@ def main_export( and kwargs_shapes.get(input_name) == 1 and task.startswith("text-generation") ): - logger.warning("Exporting with a sequence length of 1 for text generation models is not supported and can yield unexpected results.") + logger.warning( + "Exporting with a sequence length of 1 for text generation models is not supported and can yield unexpected results." 
+ ) torch_dtype = None if fp16 is False else torch.float16 From 2df564de9e211ffaa96d73806bb55a8989055ede Mon Sep 17 00:00:00 2001 From: baskrahmer Date: Tue, 31 Oct 2023 19:39:53 +0000 Subject: [PATCH 04/10] Remove Falcon attention mask patching --- optimum/exporters/onnx/model_patcher.py | 29 +------------------------ optimum/utils/modeling_utils.py | 16 +------------- 2 files changed, 2 insertions(+), 43 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 938ed042a3c..5c9a554265c 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -18,14 +18,12 @@ import types from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union -import transformers from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.falcon.modeling_falcon import FalconModel, build_alibi_tensor +from transformers.models.falcon.modeling_falcon import build_alibi_tensor from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet from transformers.utils import is_torch_available from ...utils.modeling_utils import ( - _falcon_prepare_attn_mask, _prepare_attn_mask, ) @@ -395,21 +393,11 @@ class FalconModelPatcher(ModelPatcher): def __enter__(self): self.patch_ops() - transformers.models.falcon.modeling_falcon._make_causal_mask = _make_causal_mask_falcon_patched - if self.real_config.task == "text-generation": self._model.transformer.forward = types.MethodType( falcon_model_forward_without_kv_reformatting, self._model.transformer ) - # In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length. 
- if isinstance(self._model, FalconModel): - self._model._prepare_attn_mask = _falcon_prepare_attn_mask - else: - self._model.transformer._prepare_attn_mask = _falcon_prepare_attn_mask - - setattr(self._model, self.orig_forward_name, self.patched_forward) - def __exit__(self, exc_type, exc_value, traceback): self.restore_ops() @@ -420,14 +408,6 @@ def __exit__(self, exc_type, exc_value, traceback): self.original_model_transformer_forward, self._model.transformer ) - transformers.models.falcon.modeling_falcon._make_causal_mask = self.original_make_causal - - # In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length. - if isinstance(self._model, FalconModel): - self._model._prepare_attn_mask = self.original_falcon_prepare_attn_mask - else: - self._model.transformer._prepare_attn_mask = self.original_falcon_prepare_attn_mask - def __init__( self, config: "OnnxConfig", @@ -439,13 +419,6 @@ def __init__( if config.task == "text-generation": self.original_model_transformer_forward = model.transformer.forward - self.original_make_causal = transformers.models.falcon.modeling_falcon._make_causal_mask - - if isinstance(model, FalconModel): - self.original_falcon_prepare_attn_mask = model._prepare_attn_mask - else: - self.original_falcon_prepare_attn_mask = model.transformer._prepare_attn_mask - self._model = model self.orig_forward_name = "forward" if hasattr(self._model, "forward") else "call" diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index 019f93c6f18..f2ce82deebf 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -131,20 +131,6 @@ def _falcon_prepare_attn_mask( f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length" f" {past_key_values_length}." 
) - combined_attention_mask = None - device = attention_mask.device - _, seq_length = input_shape - - # if seq_length > 1: - # NOTE: we remove here the `if seq_length > 1` to allow to use a single decoder. - combined_attention_mask = _make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length] - expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask + return _expand_mask(attention_mask, past_key_values_length=past_key_values_length) From 780582f67ac66b7e4d975d59f9e7ff9d581ce896 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 2 Nov 2023 10:28:39 +0100 Subject: [PATCH 05/10] lots of cleaning --- optimum/exporters/onnx/__main__.py | 41 ++++++++-------- optimum/exporters/onnx/base.py | 29 ++++++++++-- optimum/exporters/onnx/config.py | 9 ++-- optimum/exporters/onnx/convert.py | 6 +++ optimum/exporters/onnx/model_configs.py | 50 +++++++++++--------- optimum/exporters/onnx/model_patcher.py | 62 ++++++------------------- optimum/exporters/onnx/utils.py | 3 -- optimum/onnxruntime/modeling_decoder.py | 2 +- optimum/utils/modeling_utils.py | 57 ----------------------- 9 files changed, 99 insertions(+), 160 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index da018490d69..4b60be14149 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -24,6 +24,7 @@ from ...commands.export.onnx import parse_args_onnx from ...utils import DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, logging +from ...utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST from 
...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from ..error_utils import AtolError, OutputMatchError, ShapeError from ..tasks import TasksManager @@ -83,16 +84,12 @@ def _get_submodels_and_onnx_configs( onnx_config_constructor = TasksManager.get_exporter_config_constructor( model=model, exporter="onnx", task=task ) - onnx_config_kwargs = {} - if task.startswith("text-generation") and legacy: - onnx_config_kwargs["no_position_ids"] = legacy - onnx_config = onnx_config_constructor( model.config, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors, - **onnx_config_kwargs, + legacy=legacy, ) onnx_config.variant = _variant @@ -329,21 +326,6 @@ def main_export( f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" ) - # get the shapes to be used to generate dummy inputs - input_shapes = {} - for input_name in DEFAULT_DUMMY_SHAPES.keys(): - input_shapes[input_name] = ( - kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] - ) - if ( - input_name == "sequence_length" - and kwargs_shapes.get(input_name) == 1 - and task.startswith("text-generation") - ): - logger.warning( - "Exporting with a sequence length of 1 for text generation models is not supported and can yield unexpected results." 
- ) - torch_dtype = None if fp16 is False else torch.float16 custom_architecture = False @@ -390,6 +372,25 @@ def main_export( is_stable_diffusion = "stable-diffusion" in task model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-") + # For MODEL_TO_PATCH_FOR_PAST architectures, when exporting the model with an input of sequence length of 1, a tracer that does not handle + # control flow will trace incorrectly the mask generation, resulting in incorrect attention masks for other sequence lengths. + # Reference: https://github.com/huggingface/transformers/blob/af3de8d87c717c4bb090f037d0d89413c195a42f/src/transformers/modeling_attn_mask_utils.py#L94 + input_shapes = {} + for input_name in DEFAULT_DUMMY_SHAPES.keys(): + input_shapes[input_name] = ( + kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] + ) + + # TODO: this may be moved rather to the OnnxConfig to avoid bloating this script. + if ( + model_type in MODEL_TO_PATCH_FOR_PAST + and input_name == "sequence_length" + and kwargs_shapes.get(input_name) == 1 + ): + raise ValueError( + f"Exporting with a sequence length of 1 a {model_type} model is not supported and can yield unexpected results." + ) + if legacy and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"): logger.warning( f"legacy=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `legacy=True` is strongly discouraged, and this option will be removed in a future release. 
Reference: https://github.com/huggingface/optimum/pull/1381" diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 1e5704e8937..c76644ab77d 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -200,6 +200,7 @@ def __init__( preprocessors: Optional[List[Any]] = None, int_dtype: str = "int64", float_dtype: str = "fp32", + legacy: bool = False, ): self.task = task self.int_dtype = int_dtype @@ -209,6 +210,7 @@ def __init__( self._preprocessors = preprocessors self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) self.variant = "default" + self.legacy = legacy def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]: """ @@ -565,6 +567,7 @@ def __init__( use_past: bool = False, use_past_in_inputs: bool = False, preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): self.use_past = use_past self.use_past_in_inputs = use_past_in_inputs @@ -572,7 +575,12 @@ def __init__( self.is_merged = False self.use_cache_branch = None super().__init__( - config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=legacy, ) @property @@ -628,11 +636,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): and "attention_mask" in dummy_inputs ): # Obtain the past sequence length from the value instead of the key (Bloom). 
- past_length = dummy_inputs["past_key_values"][0][1].shape[-2] + past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[-2] dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( dummy_inputs["attention_mask"], - desired_length=past_length + 1, + desired_length=past_present_length, dim=1, dtype=dummy_inputs["attention_mask"].dtype, ) @@ -658,11 +666,15 @@ def overwrite_shape_and_generate_input( # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name # while models from TextDecoderOnnxConfig use input_ids, hence the check for both + + # TODO: The check `self.task != "text-generation" and not self.legacy` is added following the use of a single ONNX for both without/with KV cache, without subgraphs. + # This overwrite may be moved to OnnxSeq2SeqConfigWithPast, but I am afraid it would break encoder-decoder models. if ( self.use_past and self.use_past_in_inputs and self.use_cache_branch is not False and input_name in ["decoder_input_ids", "input_ids", "position_ids"] + and ((self.task == "text-generation" and not self.legacy) or self.task != "text-generation") ): sequence_length = dummy_input_gen.sequence_length # Use a sequence length of 1 when the KV cache is already populated. 
@@ -768,6 +780,7 @@ def __init__( use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): super().__init__( config=config, @@ -777,6 +790,7 @@ def __init__( use_past=use_past, use_past_in_inputs=use_past_in_inputs, preprocessors=preprocessors, + legacy=legacy, ) self._behavior = behavior @@ -1003,7 +1017,7 @@ class OnnxConfigWithLoss(OnnxConfig, ABC): DUMMY_EXTRA_INPUT_GENERATOR_CLASSES = (DummyLabelsGenerator,) - def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: str = "fp32"): + def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: str = "fp32", legacy: bool = False): self._onnx_config = config self.task = self._onnx_config.task self.int_dtype = int_dtype @@ -1011,6 +1025,7 @@ def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: st self._normalized_config = self._onnx_config._normalized_config self.PATCHING_SPECS = self._onnx_config.PATCHING_SPECS self.variant = "default" + self.legacy = legacy @classmethod def from_onnx_config(cls, config: OnnxConfig) -> "OnnxConfigWithLoss": @@ -1037,7 +1052,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): batch_size = dummy_inputs[input_name].shape[0] # TODO: doesn't this break attention_mask generation? 
- if isinstance(self._onnx_config, OnnxConfigWithPast) and self._onnx_config.use_past_in_inputs is True: + if ( + isinstance(self._onnx_config, OnnxConfigWithPast) + and self._onnx_config.use_past_in_inputs is True + and self.task != "text-generation" + ): kwargs["sequence_length"] = 1 else: for input_name, dynamic_axes in self._tasks_to_extra_inputs[self.task].items(): diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 7b7d8b19a50..e3ae2a12db4 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -75,7 +75,7 @@ def __init__( use_past: bool = False, use_past_in_inputs: bool = False, preprocessors: Optional[List[Any]] = None, - no_position_ids: bool = False, + legacy: bool = False, ): super().__init__( config=config, @@ -85,9 +85,8 @@ def __init__( use_past=use_past, use_past_in_inputs=use_past_in_inputs, preprocessors=preprocessors, + legacy=legacy, ) - # TODO: remove no_position_ids once optimum is sufficiently above 1.13 - self.no_position_ids = no_position_ids @property def inputs(self) -> Dict[str, Dict[int, str]]: @@ -163,7 +162,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: # Decoders based on GPT2 require a position_ids input to avoid # generating wrong position_ids in the model itself: # https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802 - if not self.no_position_ids and self.task in ["text-generation", "feature-extraction"]: + if not self.legacy and self.task in ["text-generation", "feature-extraction"]: common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} return common_inputs @@ -316,6 +315,7 @@ def __init__( use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): super().__init__( config=config, @@ -326,6 +326,7 @@ def __init__( use_past_in_inputs=use_past_in_inputs, behavior=behavior, 
preprocessors=preprocessors, + legacy=legacy, ) from ..tasks import TasksManager diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 0b00667e6c8..dbff4e1f94f 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -568,6 +568,12 @@ def remap(value): input_names = list(inputs.keys()) output_names = list(config.outputs.keys()) + for name, inp in dummy_inputs.items(): + if isinstance(inp, torch.Tensor): + print(name, inp.shape) + else: + print(name, type(inp)) + # Export can work with named args but the dict containing named args has to be the last element of the args # tuple. onnx_export( diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index f3846eb04fb..5c59cff8bac 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -61,7 +61,6 @@ VisionOnnxConfig, ) from .model_patcher import ( - BloomModelPatcher, FalconModelPatcher, SAMModelPatcher, SpeechT5ModelPatcher, @@ -251,11 +250,6 @@ class MPTOnnxConfig(TextDecoderOnnxConfig): num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers" ) - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return BloomModelPatcher(self, model, model_kwargs=model_kwargs) - class BloomOnnxConfig(TextDecoderOnnxConfig): # Bloom does not require position_ids input. 
@@ -286,11 +280,6 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire 1: decoder_sequence_name, } - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return BloomModelPatcher(self, model, model_kwargs=model_kwargs) - class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( @@ -322,6 +311,9 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): class FalconOnnxConfig(TextDecoderOnnxConfig): + MIN_TRANSFORMERS_VERSION = version.parse( + "4.34.99" + ) # This is because of the patching that uses _prepare_4d_causal_attention_mask from transformers>=4.35 DUMMY_INPUT_GENERATOR_CLASSES = ( MultiQueryPastKeyValuesGenerator, ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES @@ -338,7 +330,7 @@ def __init__( use_past: bool = False, use_past_in_inputs: bool = False, preprocessors: Optional[List[Any]] = None, - no_position_ids: bool = False, + legacy: bool = False, ): super().__init__( config=config, @@ -348,7 +340,7 @@ def __init__( use_past=use_past, use_past_in_inputs=use_past_in_inputs, preprocessors=preprocessors, - no_position_ids=no_position_ids, + legacy=legacy, ) # For some reason Falcon config.num_kv_heads can not be trusted, see in Transformers: # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L337 @@ -362,11 +354,7 @@ def __init__( def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = super().inputs - if ( - not self.no_position_ids - and not self._config.alibi - and self.task in ["text-generation", "feature-extraction"] - ): + if not self.legacy and not self._config.alibi and self.task in ["text-generation", "feature-extraction"]: # When alibi is used, position_ids are not used in Falcon. 
# Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116 common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} @@ -1009,9 +997,15 @@ def __init__( int_dtype: str = "int64", float_dtype: str = "fp32", preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): super().__init__( - config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=legacy, ) if task == "zero-shot-object-detection": logger.warning( @@ -1150,9 +1144,15 @@ def __init__( int_dtype: str = "int64", float_dtype: str = "fp32", preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): super().__init__( - config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=legacy, ) self.is_generating_dummy_inputs = False @@ -1324,6 +1324,7 @@ def __init__( behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, is_postnet_and_vocoder: bool = False, + legacy: bool = False, ): super().__init__( config=config, @@ -1334,6 +1335,7 @@ def __init__( use_past_in_inputs=use_past_in_inputs, behavior=behavior, preprocessors=preprocessors, + legacy=legacy, ) if float_dtype == "fp16": raise ValueError( @@ -1568,9 +1570,15 @@ def __init__( variant: str = "split", vision_encoder: Optional[bool] = None, preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): super().__init__( - config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=legacy, ) self.variant = variant self.vision_encoder = 
vision_encoder diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 5c9a554265c..bddc48fdd95 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -18,22 +18,25 @@ import types from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union +from packaging import version from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from transformers.models.falcon.modeling_falcon import build_alibi_tensor from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet from transformers.utils import is_torch_available -from ...utils.modeling_utils import ( - _prepare_attn_mask, -) - if is_torch_available(): import torch +from ...configuration_utils import _transformers_version from ...utils import logging +if _transformers_version > version.parse("4.34.99"): + from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +else: + _prepare_4d_causal_attention_mask = None + if TYPE_CHECKING: from transformers import PreTrainedModel, TFPreTrainedModel @@ -245,31 +248,6 @@ def __init__( model.decoder.model.decoder.config.use_cache = True -def _make_causal_mask_falcon_patched( - input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int -) -> torch.BoolTensor: - """ - Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it - just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1, - target_length, target_length+past_key_values_length]`. - """ - batch_size, target_length = input_ids_shape - - # NOTE: ONNX Runtime is not able to run ONNX Trilu node with bool input. As a workaround, we pass a float input - # and cast to bool here. 
Reference: https://github.com/microsoft/onnxruntime/issues/16189 - mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.float, device=device), diagonal=1).to( - torch.bool - ) - - # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op. - # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this - # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later. - past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device) - mask = torch.cat([past_mask, mask], dim=-1) - expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) - return expanded_mask - - def falcon_model_forward_without_kv_reformatting( self, input_ids: Optional[torch.LongTensor] = None, @@ -283,6 +261,8 @@ def falcon_model_forward_without_kv_reformatting( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): + # TODO: We may remove this patch once https://github.com/huggingface/transformers/pull/26933 is merged & released in Transformers. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -342,10 +322,9 @@ def falcon_model_forward_without_kv_reformatting( else: position_ids = position_ids.view(-1, seq_length).long() - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -355,7 +334,7 @@ def falcon_model_forward_without_kv_reformatting( outputs = block( hidden_states, layer_past=layer_past, - attention_mask=causal_mask, + attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, @@ -755,18 +734,3 @@ def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if self.patch: setattr(self._model_to_patch, self._orig_func_name, self._orig_func.__get__(self._model_to_patch)) - - -class BloomModelPatcher(CausalAttentionMaskModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - if self.patch: - self._model_to_patch = model.transformer - self._patch_func = _prepare_attn_mask - self._orig_func_name = "_prepare_attn_mask" - self._orig_func = self._model_to_patch._prepare_attn_mask diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 2e49020373f..b711a8a23af 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -30,7 +30,6 @@ logging, ) from ...utils.import_utils import 
_diffusers_version -from ...utils.modeling_utils import _prepare_attn_mask # noqa: F401 from ..tasks import TasksManager from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME @@ -255,8 +254,6 @@ def get_decoder_models_for_export( models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy) onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype} - if model.config.model_type.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: - onnx_kwargs["no_position_ids"] = config.no_position_ids if legacy: onnx_config = config.__class__( diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 13aef3546a5..94418a96afe 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -473,7 +473,7 @@ def _from_pretrained( if file_name == ONNX_DECODER_WITH_PAST_NAME and config.model_type in MODEL_TO_PATCH_FOR_PAST: raise ValueError( - f"{ONNX_DECODER_WITH_PAST_NAME} not supported for the following architecture : {', '.join(MODEL_TO_PATCH_FOR_PAST)}. Please re-export your model or set use_cache=False." + f"ONNX Runtime inference using {ONNX_DECODER_WITH_PAST_NAME} has been deprecated for {config.model_type} architecture. Please re-export your model with optimum>=1.14.0 or set use_cache=False. For details about the deprecation, please refer to https://github.com/huggingface/optimum/releases/tag/v1.14.0." ) regular_file_names = [] diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index f2ce82deebf..522a0fcb1f6 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -13,7 +13,6 @@ # limitations under the License. 
import functools -from typing import Tuple import torch @@ -78,59 +77,3 @@ def _make_causal_mask( ) return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) - - -# NOTE: For MODEL_TO_PATCH_FOR_PAST architectures, when exporting the model with an input of sequence length of 1, the attention masks will be generated incorrectly for other sequence length -# https://github.com/huggingface/transformers/blob/0ee45906845c8d58b9bd2df5acd90e09b00047ff/src/transformers/models/bloom/modeling_bloom.py#L654 -# The method taking care of the decoder mask generation of the models from these architectures must be patched during export for sequence length of 1. - - -# Modified from transformers.models.bloom.modeling_bloom._prepare_attn_mask -def _prepare_attn_mask( - self, - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - past_key_values_length: int, -) -> torch.BoolTensor: - from transformers.models.bloom.modeling_bloom import _expand_mask - - # create causal mask - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - combined_attention_mask = _make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - - -def _falcon_prepare_attn_mask( - attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int -) -> torch.BoolTensor: - from transformers.models.falcon.modeling_falcon import ( - _expand_mask, - ) - - # NOTE: there is no "copied from" for falcon in transformers which makes no sense to me. 
- - # Create a causal mask - # The attention mask we receive as input should cover the whole extended sequence, including any past - # cache, so its shape should be [batch_size, seq_length + past_key_values_length] - # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length] - if input_shape[1] + past_key_values_length != attention_mask.shape[1]: - raise ValueError( - "Attention mask shape should be (batch_size, seq_length + past_key_values_length)" - f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length" - f" {past_key_values_length}." - ) - - # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length] - return _expand_mask(attention_mask, past_key_values_length=past_key_values_length) From 3f11d1a42750fa51c4f5cea2e79d9f4b62a6ee42 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 2 Nov 2023 11:25:56 +0100 Subject: [PATCH 06/10] fix mistral --- optimum/exporters/onnx/__main__.py | 4 +- optimum/exporters/onnx/config.py | 9 +++ optimum/exporters/onnx/convert.py | 6 -- optimum/exporters/onnx/model_configs.py | 9 ++- optimum/exporters/onnx/model_patcher.py | 83 ++++++++++++++++++------- 5 files changed, 77 insertions(+), 34 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 4b60be14149..654f9a649e1 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -314,6 +314,8 @@ def main_export( model_name_or_path, subfolder=subfolder, library_name=library_name ) + torch_dtype = None if fp16 is False else torch.float16 + if task == "auto": try: task = TasksManager.infer_task_from_model(model_name_or_path) @@ -326,8 +328,6 @@ def main_export( f"The task could not be automatically inferred as this is available only for models hosted on the Hugging Face Hub. 
Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" ) - torch_dtype = None if fp16 is False else torch.float16 - custom_architecture = False if library_name == "transformers": config = AutoConfig.from_pretrained( diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index e3ae2a12db4..2eaa78d85e4 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -35,11 +35,14 @@ ) from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME +from .model_patcher import DecoderModelPatcher if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedModel + from .model_patcher import ModelPatcher + if is_tf_available(): from transformers import TFPreTrainedModel @@ -153,6 +156,12 @@ def post_process_exported_models( return models_and_onnx_configs, onnx_files_subpaths + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + # Refer to DecoderModelPatcher. + return DecoderModelPatcher(self, model, model_kwargs=model_kwargs) + class TextDecoderWithPositionIdsOnnxConfig(TextDecoderOnnxConfig): @property diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index dbff4e1f94f..0b00667e6c8 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -568,12 +568,6 @@ def remap(value): input_names = list(inputs.keys()) output_names = list(config.outputs.keys()) - for name, inp in dummy_inputs.items(): - if isinstance(inp, torch.Tensor): - print(name, inp.shape) - else: - print(name, type(inp)) - # Export can work with named args but the dict containing named args has to be the last element of the args # tuple. 
onnx_export( diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 5c59cff8bac..1bb998f138c 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -234,6 +234,9 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): + # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35 + MIN_TRANSFORMERS_VERSION = version.parse("4.34.99") + # The ONNX export of this architecture needs the Trilu operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 DUMMY_INPUT_GENERATOR_CLASSES = ( @@ -311,9 +314,9 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): class FalconOnnxConfig(TextDecoderOnnxConfig): - MIN_TRANSFORMERS_VERSION = version.parse( - "4.34.99" - ) # This is because of the patching that uses _prepare_4d_causal_attention_mask from transformers>=4.35 + # This is because of the patching that uses _prepare_4d_causal_attention_mask from transformers>=4.35 + MIN_TRANSFORMERS_VERSION = version.parse("4.34.99") + DUMMY_INPUT_GENERATOR_CLASSES = ( MultiQueryPastKeyValuesGenerator, ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index bddc48fdd95..5146d9482a0 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -33,9 +33,10 @@ if _transformers_version > version.parse("4.34.99"): - from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask + from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask else: _prepare_4d_causal_attention_mask = None + AttentionMaskConverter = None if TYPE_CHECKING: from transformers import PreTrainedModel, TFPreTrainedModel @@ -368,7 +369,64 @@ def 
falcon_model_forward_without_kv_reformatting( ) -class FalconModelPatcher(ModelPatcher): +def _make_causal_mask_patched( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +): + """ + Make causal mask used for bi-directional self-attention. + """ + # We add self in the signature because `self._make_causal_mask` is used elsewhere in the class definition, despite the method being a staticmethod. + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + # NOTE: adding dtype=torch.int64 here for triu to be supported by ORT: https://github.com/microsoft/onnxruntime/issues/16189 + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int64), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +class DecoderModelPatcher(ModelPatcher): + def __enter__(self): + # TODO: Remove this if once transformers if much above 4.35 + if AttentionMaskConverter is not None: + AttentionMaskConverter._make_causal_mask = _make_causal_mask_patched + + def __exit__(self, exc_type, exc_value, traceback): + # TODO: Remove this if once transformers if much above 4.35 + if AttentionMaskConverter is not None: + AttentionMaskConverter._make_causal_mask = self.original_make_causal + + def __init__( + self, + config: "OnnxConfig", + model: 
Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + # TODO: Remove this if once transformers if much above 4.35 + if AttentionMaskConverter is not None: + self.original_make_causal = AttentionMaskConverter._make_causal_mask + + +class FalconModelPatcher(DecoderModelPatcher): def __enter__(self): self.patch_ops() @@ -713,24 +771,3 @@ def patched_forward( return filterd_outputs self.patched_forward = patched_forward - - -class CausalAttentionMaskModelPatcher(ModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - self.patch = self.real_config.task == "text-generation" and self.real_config.use_past - - def __enter__(self): - super().__enter__() - if self.patch: - setattr(self._model_to_patch, self._orig_func_name, self._patch_func.__get__(self._model_to_patch)) - - def __exit__(self, exc_type, exc_value, traceback): - super().__exit__(exc_type, exc_value, traceback) - if self.patch: - setattr(self._model_to_patch, self._orig_func_name, self._orig_func.__get__(self._model_to_patch)) From 6c53370c6d4cddff5f2cdcb9553d0b57fbce58dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 2 Nov 2023 11:29:46 +0100 Subject: [PATCH 07/10] fix legacy --- optimum/exporters/onnx/base.py | 1 + optimum/exporters/onnx/utils.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index c76644ab77d..e0ba78f972a 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -830,6 +830,7 @@ def with_behavior( use_past_in_inputs=use_past_in_inputs, behavior=behavior, preprocessors=self._preprocessors, + legacy=legacy, ) onnx_config.variant = self.variant 
return onnx_config diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index b711a8a23af..dbdcbf18d03 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -253,7 +253,7 @@ def get_decoder_models_for_export( models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy) - onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype} + onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype, "legacy": legacy} if legacy: onnx_config = config.__class__( @@ -386,14 +386,14 @@ def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel models_for_export = _get_submodels_for_export_sam(model, config.variant) if config.variant == "monolith": - onnx_config = config.__class__(model.config, task=config.task) + onnx_config = config.__class__(model.config, task=config.task, legacy=legacy) models_for_export["model"] = (models_for_export["model"], onnx_config) else: vision_encoder_onnx_config = config.__class__( - model.config, task=config.task, variant=config.variant, vision_encoder=True + model.config, task=config.task, variant=config.variant, vision_encoder=True, legacy=legacy ) prompt_encoder_mask_decoder_onnx_config = config.__class__( - model.config, task=config.task, variant=config.variant, vision_encoder=False + model.config, task=config.task, variant=config.variant, vision_encoder=False, legacy=legacy ) models_for_export["vision_encoder"] = (models_for_export["vision_encoder"], vision_encoder_onnx_config) models_for_export["prompt_encoder_mask_decoder"] = ( @@ -451,6 +451,7 @@ def get_speecht5_models_for_export( behavior=config._behavior, # Irrelevant here. 
preprocessors=config._preprocessors, is_postnet_and_vocoder=True, + legacy=legacy, ) postnet_and_vocoder_onnx_config.variant = config.variant models_for_export["decoder_postnet_and_vocoder"] = ( From 967b8c9105722ef60850292535de325b04e026e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 2 Nov 2023 11:49:53 +0100 Subject: [PATCH 08/10] more fixes --- optimum/exporters/onnx/base.py | 6 +++--- optimum/exporters/onnx/model_patcher.py | 9 +++++++-- optimum/exporters/onnx/utils.py | 15 ++++++++++----- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index e0ba78f972a..8371953ab70 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -667,14 +667,14 @@ def overwrite_shape_and_generate_input( # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name # while models from TextDecoderOnnxConfig use input_ids, hence the check for both - # TODO: The check `self.task != "text-generation" and not self.legacy` is added following the use of a single ONNX for both without/with KV cache, without subgraphs. + # TODO: The check `self.task != "text-generation" and self.legacy` is added following the use of a single ONNX for both without/with KV cache, without subgraphs. # This overwrite may be moved to OnnxSeq2SeqConfigWithPast, but I am afraid it would break encoder-decoder models. if ( self.use_past and self.use_past_in_inputs and self.use_cache_branch is not False and input_name in ["decoder_input_ids", "input_ids", "position_ids"] - and ((self.task == "text-generation" and not self.legacy) or self.task != "text-generation") + and ((self.task == "text-generation" and self.legacy) or self.task != "text-generation") ): sequence_length = dummy_input_gen.sequence_length # Use a sequence length of 1 when the KV cache is already populated. 
@@ -830,7 +830,7 @@ def with_behavior( use_past_in_inputs=use_past_in_inputs, behavior=behavior, preprocessors=self._preprocessors, - legacy=legacy, + legacy=self.legacy, ) onnx_config.variant = self.variant return onnx_config diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 5146d9482a0..f4d8eb2eb71 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -410,8 +410,11 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): # TODO: Remove this if once transformers if much above 4.35 - if AttentionMaskConverter is not None: - AttentionMaskConverter._make_causal_mask = self.original_make_causal + # TODO: We should unpatch it - however `self._make_causal_mask` may still be called later which raises issues with this simple patch strategy. + # We need to find a proper solution. + # if AttentionMaskConverter is not None: + # AttentionMaskConverter._make_causal_mask = self.original_make_causal + pass def __init__( self, @@ -428,6 +431,7 @@ def __init__( class FalconModelPatcher(DecoderModelPatcher): def __enter__(self): + super().__enter__() self.patch_ops() if self.real_config.task == "text-generation": @@ -436,6 +440,7 @@ def __enter__(self): ) def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) self.restore_ops() setattr(self._model, self.orig_forward_name, self.orig_forward) diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index dbdcbf18d03..c1737fc087c 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -253,7 +253,12 @@ def get_decoder_models_for_export( models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy) - onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype, "legacy": legacy} + onnx_kwargs = { + "task": config.task, + "float_dtype": 
config.float_dtype, + "int_dtype": config.int_dtype, + "legacy": legacy, + } if legacy: onnx_config = config.__class__( @@ -386,14 +391,14 @@ def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel models_for_export = _get_submodels_for_export_sam(model, config.variant) if config.variant == "monolith": - onnx_config = config.__class__(model.config, task=config.task, legacy=legacy) + onnx_config = config.__class__(model.config, task=config.task, legacy=config.legacy) models_for_export["model"] = (models_for_export["model"], onnx_config) else: vision_encoder_onnx_config = config.__class__( - model.config, task=config.task, variant=config.variant, vision_encoder=True, legacy=legacy + model.config, task=config.task, variant=config.variant, vision_encoder=True, legacy=config.legacy ) prompt_encoder_mask_decoder_onnx_config = config.__class__( - model.config, task=config.task, variant=config.variant, vision_encoder=False, legacy=legacy + model.config, task=config.task, variant=config.variant, vision_encoder=False, legacy=config.legacy ) models_for_export["vision_encoder"] = (models_for_export["vision_encoder"], vision_encoder_onnx_config) models_for_export["prompt_encoder_mask_decoder"] = ( @@ -451,7 +456,7 @@ def get_speecht5_models_for_export( behavior=config._behavior, # Irrelevant here. 
preprocessors=config._preprocessors, is_postnet_and_vocoder=True, - legacy=legacy, + legacy=config.legacy, ) postnet_and_vocoder_onnx_config.variant = config.variant models_for_export["decoder_postnet_and_vocoder"] = ( From 388128048e1db1b358714d1eb6fb423ad9396f55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 2 Nov 2023 13:11:24 +0100 Subject: [PATCH 09/10] fix make_causal patching --- optimum/exporters/onnx/model_patcher.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index f4d8eb2eb71..09cdddc95fe 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -370,7 +370,6 @@ def falcon_model_forward_without_kv_reformatting( def _make_causal_mask_patched( - self, input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, @@ -402,6 +401,9 @@ def _make_causal_mask_patched( return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) +_make_causal_mask_patched = staticmethod(_make_causal_mask_patched) + + class DecoderModelPatcher(ModelPatcher): def __enter__(self): # TODO: Remove this if once transformers if much above 4.35 @@ -410,11 +412,8 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): # TODO: Remove this if once transformers if much above 4.35 - # TODO: We should unpatch it - however `self._make_causal_mask` may still be called later which raises issues with this simple patch strategy. - # We need to find a proper solution. 
- # if AttentionMaskConverter is not None: - # AttentionMaskConverter._make_causal_mask = self.original_make_causal - pass + if AttentionMaskConverter is not None: + AttentionMaskConverter._make_causal_mask = self.original_make_causal def __init__( self, From 05fc5f7adc9c1fff1843b9011479981e5ab845c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com> Date: Thu, 2 Nov 2023 13:45:15 +0100 Subject: [PATCH 10/10] remove unused method --- optimum/utils/modeling_utils.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index 522a0fcb1f6..dae5b5d633a 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -14,8 +14,6 @@ import functools -import torch - MODEL_TO_PATCH_FOR_PAST = { "bart", @@ -54,26 +52,3 @@ def recurse_setattr(module, name, value): else: name, rest = name.split(".", 1) recurse_setattr(getattr(module, name), rest, value) - - -# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, - device: torch.device, - past_key_values_length: int, - dtype: torch.dtype = torch.bool, -) -> torch.BoolTensor: - """ - Make causal mask used for bi-directional self-attention. - """ - batch_size, target_length = input_ids_shape - mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device) - seq_ids = torch.arange(target_length, device=device) - - mask[:, past_key_values_length:] = ( - (seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min - if torch.is_floating_point(mask) - else seq_ids[:, None] < seq_ids[None, :] - ) - - return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)