diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 0faf5048f60..9e808e392b9 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -289,12 +289,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]: if self._behavior is not ConfigBehavior.ENCODER: if self.use_past_in_inputs: common_inputs["decoder_input_ids"] = {0: "batch_size"} + self.add_past_key_values(common_inputs, direction="inputs") else: common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} - if self.use_past_in_inputs: - self.add_past_key_values(common_inputs, direction="inputs") - if self._behavior is ConfigBehavior.DECODER: common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"} diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 26202e889b8..3e11c7e614a 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -53,6 +53,7 @@ NormalizedTextConfig, NormalizedTextConfigWithGQA, NormalizedVisionConfig, + check_if_transformers_greater, is_diffusers_available, logging, ) @@ -71,6 +72,7 @@ ) from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME from .model_patcher import ( + CLIPModelPatcher, FalconModelPatcher, MistralModelPatcher, MusicgenModelPatcher, @@ -913,10 +915,16 @@ def outputs(self) -> Dict[str, Dict[int, str]]: return common_outputs + def patch_model_for_export( + self, + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> "ModelPatcher": + return CLIPModelPatcher(self, model, model_kwargs=model_kwargs) + class CLIPOnnxConfig(TextAndVisionOnnxConfig): NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig - DEFAULT_ONNX_OPSET = 14 @property def inputs(self) -> Dict[str, Dict[int, str]]: @@ -935,6 +943,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]: "image_embeds": {0: 
"image_batch_size"}, } + def patch_model_for_export( + self, + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> "ModelPatcher": + return CLIPModelPatcher(self, model, model_kwargs=model_kwargs) + class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig): @property @@ -980,6 +995,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]: return common_outputs + def patch_model_for_export( + self, + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> "ModelPatcher": + return CLIPModelPatcher(self, model, model_kwargs=model_kwargs) + class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig): @property @@ -997,12 +1019,20 @@ def outputs(self) -> Dict[str, Dict[int, str]]: def generate_dummy_inputs(self, framework: str = "pt", **kwargs): dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) + # TODO: fix should be by casting inputs during inference and not export if framework == "pt": import torch dummy_inputs["input_ids"] = dummy_inputs["input_ids"].to(dtype=torch.int32) return dummy_inputs + def patch_model_for_export( + self, + model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], + model_kwargs: Optional[Dict[str, Any]] = None, + ) -> "ModelPatcher": + return CLIPModelPatcher(self, model, model_kwargs=model_kwargs) + class UNetOnnxConfig(VisionOnnxConfig): ATOL_FOR_VALIDATION = 1e-3 @@ -1135,6 +1165,9 @@ class OwlViTOnnxConfig(CLIPOnnxConfig): ATOL_FOR_VALIDATION = 1e-4 MIN_TORCH_VERSION = version.parse("2.1") + # needs einsum operator support, available since opset 12 + DEFAULT_ONNX_OPSET = 12 + def __init__( self, config: "PretrainedConfig", @@ -1438,7 +1471,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]: if self._behavior is not ConfigBehavior.DECODER: common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis. 
- if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False: + if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs: + if check_if_transformers_greater("4.43.0"): + # since https://github.com/huggingface/transformers/pull/31166 + common_inputs["cache_position"] = {0: "decoder_sequence_length"} + + if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs: common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2" return common_inputs diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 4c1f8458930..34ed5fcae46 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -1138,3 +1138,20 @@ def __init__( self._update_causal_mask_original = self._model.model._update_causal_mask else: self._update_causal_mask_original = self._model._update_causal_mask + + +class CLIPModelPatcher(ModelPatcher): + def __enter__(self): + super().__enter__() + + if _transformers_version >= version.parse("4.43"): + from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention + + self.original_sdpa_forward, CLIPSdpaAttention.forward = CLIPSdpaAttention.forward, CLIPAttention.forward + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if _transformers_version >= version.parse("4.43"): + from transformers.models.clip.modeling_clip import CLIPSdpaAttention + + CLIPSdpaAttention.forward = self.original_sdpa_forward diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 902dd89f777..e2125736c4d 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -96,7 +96,7 @@ def _get_submodels_for_export_diffusion( pipeline, (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline) ) is_stable_diffusion_xl = isinstance( - pipeline, (StableDiffusionXLImg2ImgPipeline, 
StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline) + pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline) ) is_latent_consistency_model = isinstance( pipeline, (LatentConsistencyModelPipeline, LatentConsistencyModelImg2ImgPipeline) @@ -117,10 +117,11 @@ def _get_submodels_for_export_diffusion( models_for_export = {} # Text encoder - if pipeline.text_encoder is not None: + text_encoder = getattr(pipeline, "text_encoder", None) + if text_encoder is not None: if is_stable_diffusion_xl: - pipeline.text_encoder.config.output_hidden_states = True - models_for_export["text_encoder"] = pipeline.text_encoder + text_encoder.config.output_hidden_states = True + models_for_export["text_encoder"] = text_encoder # U-NET # ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0 @@ -151,6 +152,7 @@ def _get_submodels_for_export_diffusion( text_encoder_2 = getattr(pipeline, "text_encoder_2", None) if text_encoder_2 is not None: text_encoder_2.config.output_hidden_states = True + text_encoder_2.text_model.config.output_hidden_states = True models_for_export["text_encoder_2"] = text_encoder_2 return models_for_export diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py index 16461dce957..d9877670ba8 100644 --- a/optimum/onnxruntime/base.py +++ b/optimum/onnxruntime/base.py @@ -24,6 +24,7 @@ from ..utils import NormalizedConfigManager from ..utils.logging import warn_once +from .io_binding import TypeHelper from .modeling_ort import ORTModel from .utils import get_ordered_input_names, logging @@ -62,6 +63,20 @@ def __init__( def device(self): return self.parent_model.device + @property + def dtype(self): + for dtype in self.input_dtypes.values(): + torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) + if torch_dtype.is_floating_point: + return torch_dtype + + for dtype in self.output_dtypes.values(): + torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) + if 
torch_dtype.is_floating_point: + return torch_dtype + + return None + @abstractmethod def forward(self, *args, **kwargs): pass @@ -220,6 +235,7 @@ def forward( encoder_attention_mask: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, use_cache_branch: None = None, ) -> Seq2SeqLMOutput: # Adding use_cache_branch in the signature here is just a hack for IO Binding @@ -236,8 +252,8 @@ def forward( # no-ops if merged decoder is not used use_merged_no_cache = past_key_values is None and self.parent_model.use_merged use_merged_cache = past_key_values is not None and self.parent_model.use_merged - use_cache_branch_tensor, past_key_values = self.prepare_inputs_for_merged( - input_ids, past_key_values, use_torch=use_torch + use_cache_branch_tensor, past_key_values, cache_position = self.prepare_inputs_for_merged( + input_ids, past_key_values, cache_position, use_torch=use_torch ) if self.parent_model.use_io_binding: @@ -274,6 +290,9 @@ def forward( if use_cache_branch_tensor is not None: model_inputs.append(use_cache_branch_tensor) + if "cache_position" in self.input_names: + model_inputs.append(cache_position) + io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding( self.session, *model_inputs, @@ -346,6 +365,7 @@ def forward( "decoder_attention_mask": decoder_attention_mask, "encoder_attention_mask": encoder_attention_mask, "use_cache_branch": use_cache_branch_tensor, + "cache_position": cache_position, "labels": labels, } if past_key_values is not None: @@ -405,20 +425,20 @@ def forward( def prepare_inputs_for_merged( self, - input_ids: Union[None, torch.LongTensor, np.ndarray], - past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]], + input_ids: Optional[Union[torch.LongTensor, np.ndarray]], + past_key_values: Optional[Tuple[Union[torch.FloatTensor, np.ndarray]]], + 
cache_position: Optional[Union[torch.Tensor, np.ndarray]], use_torch: bool, ): + constructor = torch if use_torch is True else np + if self.parent_model.use_merged: - constructor = torch if use_torch is True else np # Uses without/with branch of a merged decoder depending on whether real past key values are passed - use_cache_branch = constructor.full((1,), past_key_values is not None) + use_cache_branch_tensor = constructor.full((1,), past_key_values is not None) + if use_torch and use_cache_branch_tensor is not None: + use_cache_branch_tensor = use_cache_branch_tensor.to(self.device) else: - # Uses separate decoders - use_cache_branch = None - - if use_torch and use_cache_branch is not None: - use_cache_branch = use_cache_branch.to(self.device) + use_cache_branch_tensor = None # Generate dummy past for the first forward if uses a merged decoder if self.parent_model.use_merged and past_key_values is None: @@ -434,7 +454,13 @@ def prepare_inputs_for_merged( past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names))) - return use_cache_branch, past_key_values + # Generate dummy position cache for the first forward if uses a merged decoder + if self.parent_model.use_merged and cache_position is None: + cache_position = constructor.zeros((1,), dtype=constructor.int64) + if use_torch is True: + cache_position = cache_position.to(self.device) + + return use_cache_branch_tensor, past_key_values, cache_position class ORTDecoder(ORTDecoderForSeq2Seq): diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index f4e54752115..4bbfb2eda2a 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -452,10 +452,14 @@ def to(self, device: Union[torch.device, str, int]): Returns: `ORTModel`: the model placed on the requested device. 
""" + device, provider_options = parse_device(device) provider = get_provider_for_device(device) validate_provider_availability(provider) # raise error if the provider is not available - self.device = device + + if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider": + return self + self.vae_decoder.session.set_providers([provider], provider_options=[provider_options]) self.text_encoder.session.set_providers([provider], provider_options=[provider_options]) self.unet.session.set_providers([provider], provider_options=[provider_options]) @@ -464,6 +468,8 @@ def to(self, device: Union[torch.device, str, int]): self.vae_encoder.session.set_providers([provider], provider_options=[provider_options]) self.providers = self.vae_decoder.session.get_providers() + self._device = device + return self @classmethod diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 126b1e65366..254b771e334 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -276,7 +276,24 @@ def __init__( self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward) - # TODO: why do we make device a property since we are only access the value, and do not do any check when setting the value? + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the model. 
+ """ + + for dtype in self.input_dtypes.values(): + torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) + if torch_dtype.is_floating_point: + return torch_dtype + + for dtype in self.output_dtypes.values(): + torch_dtype = TypeHelper.ort_type_to_torch_type(dtype) + if torch_dtype.is_floating_point: + return torch_dtype + + return None + @property def device(self) -> torch.device: """ @@ -286,8 +303,8 @@ def device(self) -> torch.device: return self._device @device.setter def device(self, value: torch.device): - self._device = value + raise AttributeError("The device attribute is read-only, please use the `to` method to change the device.") @property def use_io_binding(self): @@ -309,13 +326,13 @@ def to(self, device: Union[torch.device, str, int]): Returns: `ORTModel`: the model placed on the requested device. """ + device, provider_options = parse_device(device) if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider": return self - self.device = device - provider = get_provider_for_device(self.device) + provider = get_provider_for_device(device) validate_provider_availability(provider) # raise error if the provider is not available # IOBinding is only supported for CPU and CUDA Execution Providers. @@ -331,6 +348,7 @@ def to(self, device: Union[torch.device, str, int]): self.model.set_providers([provider], provider_options=[provider_options]) self.providers = self.model.get_providers() + self._device = device return self diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py index 3b1af05d0f5..4ce3e4707ed 100644 --- a/optimum/onnxruntime/modeling_seq2seq.py +++ b/optimum/onnxruntime/modeling_seq2seq.py @@ -16,7 +16,6 @@ Transformers. 
""" -import copy import logging import shutil import warnings @@ -34,13 +33,12 @@ AutoModelForSpeechSeq2Seq, AutoModelForVision2Seq, GenerationConfig, - Pix2StructForConditionalGeneration, # Pix2struct does not support AutoModel + Pix2StructForConditionalGeneration, + WhisperForConditionalGeneration, ) from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward -from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput from transformers.models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES -from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE import onnxruntime as ort @@ -73,6 +71,22 @@ else: from transformers.generation_utils import GenerationMixin + +# if check_if_transformers_greater("4.37.0"): +# # starting from transformers v4.37.0, the whisper generation loop is implemented in the `WhisperGenerationMixin` +# # and it implements many new features including short and long form generation, and starts with 2 init tokens +# from transformers.models.whisper.generation_whisper import WhisperGenerationMixin +# else: + +# class WhisperGenerationMixin(WhisperForConditionalGeneration, GenerationMixin): +# pass + + +if check_if_transformers_greater("4.43.0"): + from transformers.cache_utils import EncoderDecoderCache +else: + EncoderDecoderCache = dict + from huggingface_hub.utils import EntryNotFoundError @@ -1104,6 +1118,14 @@ def _from_transformers( model_save_dir=save_dir, ) + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the model. + """ + + return self.encoder.dtype or self.decoder.dtype + def to(self, device: Union[torch.device, str, int]): """ Changes the ONNX Runtime provider according to the device. 
@@ -1124,12 +1146,12 @@ def to(self, device: Union[torch.device, str, int]): provider = get_provider_for_device(device) validate_provider_availability(provider) # raise error if the provider is not available - self.device = device self.encoder.session.set_providers([provider], provider_options=[provider_options]) self.decoder.session.set_providers([provider], provider_options=[provider_options]) if self.decoder_with_past is not None: self.decoder_with_past.session.set_providers([provider], provider_options=[provider_options]) self.providers = self.encoder.session.get_providers() + self._device = device return self @@ -1338,6 +1360,7 @@ def forward( encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, labels: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Seq2SeqLMOutput: # Encode if needed : first prediction pass @@ -1354,6 +1377,7 @@ def forward( past_key_values=past_key_values, encoder_hidden_states=encoder_outputs.last_hidden_state, encoder_attention_mask=attention_mask, + cache_position=cache_position, labels=labels, ) @@ -1365,30 +1389,25 @@ def forward( def prepare_inputs_for_generation( self, - input_ids, - attention_mask=None, + decoder_input_ids, past_key_values=None, + attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, use_cache=None, encoder_outputs=None, **kwargs, - ) -> Dict: + ): + # cut decoder_input_ids if past is used if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - input_ids = input_ids[:, remove_prefix_length:] + decoder_input_ids = decoder_input_ids[:, -1:] return { - "decoder_input_ids": input_ids, - 
"past_key_values": past_key_values, "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, "head_mask": head_mask, "decoder_head_mask": decoder_head_mask, "cross_attn_head_mask": cross_attn_head_mask, @@ -1422,11 +1441,17 @@ def _from_pretrained( return super()._from_pretrained(model_id, config, **kwargs) -class _ORTModelForWhisper(ORTModelForSpeechSeq2Seq): +class _ORTModelForWhisper(ORTModelForSpeechSeq2Seq, WhisperForConditionalGeneration): """ Whisper implements its own generate() method. """ + auto_model_class = WhisperForConditionalGeneration + + # force the use of the WhisperForConditionalGeneration generate and prepare_inputs_for_generation methods + prepare_inputs_for_generation = WhisperForConditionalGeneration.prepare_inputs_for_generation + generate = WhisperForConditionalGeneration.generate + @classmethod def _from_pretrained( cls, @@ -1436,418 +1461,22 @@ def _from_pretrained( ): return super(ORTModelForSpeechSeq2Seq, cls)._from_pretrained(model_id, config, **kwargs) - # Adapted from transformers.models.whisper.modeling_whisper - def generate( - self, - input_features: Optional[torch.Tensor] = None, - generation_config=None, - logits_processor=None, - stopping_criteria=None, - prefix_allowed_tokens_fn=None, - synced_gpus=False, - return_timestamps=None, - task=None, - language=None, - is_multilingual=None, - prompt_ids: Optional[torch.Tensor] = None, - num_segment_frames: Optional[int] = None, - return_token_timestamps: Optional[bool] = None, - return_segments: bool = False, - attention_mask: Optional[torch.Tensor] = None, - time_precision: int = 0.02, - return_dict_in_generate: Optional[bool] = None, - **kwargs, - ): - if "inputs" in kwargs: - input_features = kwargs.pop("inputs") - warnings.warn( - "The input name `inputs` is deprecated. 
Please make sure to use `input_features` instead.", - FutureWarning, - ) - - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - - if generation_config is None: - generation_config = copy.deepcopy(self.generation_config) - - input_stride = ( - 1 * 2 - ) # NOTE: replaced from `self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]` - if num_segment_frames is None: - num_segment_frames = input_stride * self.config.max_source_positions - - # 1. Check whether we're in shortform or longform mode - if input_features is not None: - total_input_frames = input_features.shape[-1] - elif "encoder_outputs" in kwargs: - encoder_outputs_shape = ( - kwargs["encoder_outputs"][0].shape - if isinstance(kwargs["encoder_outputs"], BaseModelOutput) - else kwargs["encoder_outputs"].shape - ) - total_input_frames = encoder_outputs_shape[1] * input_stride - else: - raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") - - is_shortform = total_input_frames <= num_segment_frames - - # 2. Make sure the generation config is correctly set depending on whether timestamps are to be returned or not - if return_timestamps is True: - if not hasattr(generation_config, "no_timestamps_token_id"): - raise ValueError( - "You are trying to return timestamps, but the generation config is not properly set. " - "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. 
" - "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" - ) - generation_config.return_timestamps = return_timestamps - elif not is_shortform: - if return_timestamps is False: - raise ValueError( - "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " - "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features." - ) - - if not hasattr(generation_config, "no_timestamps_token_id"): - raise ValueError( - "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " - "requires the generation config to have `no_timestamps_token_id` correctly. " - "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " - "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" - "or make sure to pass no more than 3000 mel input features." - ) - - logger.info("Setting `return_timestamps=True` for long-form generation.") - generation_config.return_timestamps = True - else: - generation_config.return_timestamps = False - - # 3. Make sure to correctly set language-related parameters - if is_multilingual is not None: - if not hasattr(generation_config, "is_multilingual"): - raise ValueError( - "The generation config is outdated and is thus not compatible with the `is_multilingual` argument " - "to `generate`. 
Please update the generation config as per the instructions " - "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" - ) - generation_config.is_multilingual = is_multilingual - - if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual: - if task is not None or language is not None: - raise ValueError( - "Cannot specify `task` or `language` for an English-only model. If the model is intended to be " - "multilingual, pass `is_multilingual=True` to generate, or update the generation config." - ) - - if language is not None: - if not hasattr(generation_config, "lang_to_id"): - raise ValueError( - "The generation config is outdated and is thus not compatible with the `language` argument " - "to `generate`. Either set the language using the `forced_decoder_ids` in the model config, " - "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" - ) - language = language.lower() - generation_config.language = language - if task is not None: - if not hasattr(generation_config, "task_to_id"): - raise ValueError( - "The generation config is outdated and is thus not compatible with the `task` argument " - "to `generate`. Either set the task using the `forced_decoder_ids` in the model config, " - "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" - ) - generation_config.task = task - - # 4. 
Add forced decoder ids depending on passed `language`, `task`,`prompt_ids`, `return_token_timestamps` and `return_timestamps` - forced_decoder_ids = None - # Legacy code for backward compatibility - if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None: - forced_decoder_ids = self.config.forced_decoder_ids - elif ( - hasattr(self.generation_config, "forced_decoder_ids") - and self.generation_config.forced_decoder_ids is not None - ): - forced_decoder_ids = self.generation_config.forced_decoder_ids - else: - forced_decoder_ids = kwargs.get("forced_decoder_ids", None) - - if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): - forced_decoder_ids = [] - if hasattr(generation_config, "language"): - if generation_config.language in generation_config.lang_to_id.keys(): - language_token = generation_config.language - elif generation_config.language in TO_LANGUAGE_CODE.keys(): - language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" - elif generation_config.language in TO_LANGUAGE_CODE.values(): - language_token = f"<|{generation_config.language}|>" - else: - is_language_code = len(generation_config.language) == 2 - raise ValueError( - f"Unsupported language: {generation_config.language}. Language should be one of:" - f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." - ) - forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) - else: - forced_decoder_ids.append((1, None)) # automatically detect the language - - if hasattr(generation_config, "task"): - if generation_config.task in TASK_IDS: - forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) - else: - raise ValueError( - f"The `{generation_config.task}`task is not supported. 
The task should be one of `{TASK_IDS}`" - ) - elif hasattr(generation_config, "task_to_id"): - forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe - if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: - idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 - forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) - - if forced_decoder_ids is not None: - generation_config.forced_decoder_ids = forced_decoder_ids - - if prompt_ids is not None: - if kwargs.get("decoder_start_token_id") is not None: - raise ValueError( - "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten." - ) - prompt_ids = prompt_ids.tolist() - decoder_start_token_id, *text_prompt_ids = prompt_ids - # Slicing the text prompt ids in a manner consistent with the OpenAI implementation - # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) - text_prompt_ids = text_prompt_ids[-self.config.max_target_positions // 2 - 1 :] - # Set the decoder_start_token_id to <|startofprev|> - kwargs.update({"decoder_start_token_id": decoder_start_token_id}) - - # If the user passes `max_new_tokens`, increase its number to account for the prompt - if kwargs.get("max_new_tokens", None) is not None: - kwargs["max_new_tokens"] += len(text_prompt_ids) - if kwargs["max_new_tokens"] >= self.config.max_target_positions: - raise ValueError( - f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` " - f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced " - f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the " - f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. 
" - "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " - f"so that their combined length is less that {self.config.max_target_positions}." - ) - - # Reformat the forced_decoder_ids to incorporate the prompt - non_prompt_forced_decoder_ids = ( - kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids - ) - forced_decoder_ids = [ - *text_prompt_ids, - generation_config.decoder_start_token_id, - *[token for _rank, token in non_prompt_forced_decoder_ids], - ] - forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] - generation_config.forced_decoder_ids = forced_decoder_ids - - if return_token_timestamps: - kwargs["output_attentions"] = True - return_dict_in_generate = True - - if getattr(generation_config, "task", None) == "translate": - logger.warning("Token-level timestamps may not be reliable for task 'translate'.") - if not hasattr(generation_config, "alignment_heads"): - raise ValueError( - "Model generation config has no `alignment_heads`, token-level timestamps not available. " - "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." 
- ) - - if kwargs.get("num_frames") is not None: - generation_config.num_frames = kwargs.pop("num_frames") - - if generation_config.return_timestamps is True: - last_forced_decoder_ids = ( - generation_config.forced_decoder_ids[-1][-1] - if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids - else None - ) - if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id: - # remove no_timestamp to be forcefully generated if we want to return timestamps - # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly - forced_decoder_ids = generation_config.forced_decoder_ids[:-1] - # Make sure that if list is empty we set it to None - generation_config.forced_decoder_ids = None if len(forced_decoder_ids) == 0 else forced_decoder_ids - - timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)] - logits_processor = ( - timestamp_processor if logits_processor is None else timestamp_processor + logits_processor - ) - - # 5. If we're in shortform mode, simple generate the whole input at once and return the output - if is_shortform: - outputs = super().generate( - input_features, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - ) - - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) - outputs["token_timestamps"] = self._extract_token_timestamps( - outputs, generation_config.alignment_heads, num_frames=num_frames - ) - - return outputs - - # 6. Else we're in longform mode which is more complex. 
We need to chunk the audio input depending on when the model generated - # timestamp tokens - # 6.1 Set running parameters for while loop - if not return_segments and return_dict_in_generate: - raise ValueError( - "Make sure to set `return_segments=True` to return generation outputs as part of the `'segments' key.`" - ) - - # if input is longer than 30 seconds we default to long-form generation - timestamp_begin = self.generation_config.no_timestamps_token_id + 1 - # input stride is mel frames per encoder output vector which is the product of all conv strides - batch_size = input_features.shape[0] - - if batch_size > 1 and attention_mask is None: - raise ValueError( - "When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " - ) - elif batch_size > 1: - max_frames = attention_mask.sum(-1).cpu().to(torch.long) - seek = torch.zeros((batch_size,), dtype=torch.long) - else: - max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames - seek = torch.zeros((1,), dtype=torch.long) - - current_segments = [[] for _ in range(batch_size)] - cur_to_prev_index_map = list(range(batch_size)) - - # batch size can decrease during the run - cur_bsz = prev_bsz = batch_size - - # 6.2 Transcribe audio until we reach the end of all input audios - while (seek < max_frames).any(): - prev_bsz = cur_bsz - - # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop - # in case one audio finished earlier than another one. 
Thus, we need to keep a table of "previous-index-2-current-index" in order - # to know which original audio is being decoded - new_cur_to_prev_index_map = [] - for i in range(prev_bsz): - prev_i = cur_to_prev_index_map[i] - if seek[prev_i] >= max_frames[prev_i]: - cut_index = i + (cur_bsz - prev_bsz) - cur_bsz -= 1 - input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0) - else: - # cut out index that goes away - new_cur_to_prev_index_map.append(prev_i) - - # 6.4 Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk - cur_to_prev_index_map = new_cur_to_prev_index_map - time_offset = seek * time_precision / input_stride - seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) - - # 6.5 Make sure that all inputs are padded to the same input length - segment_input = [] - for i in range(cur_bsz): - prev_i = cur_to_prev_index_map[i] - segment_input_slice = input_features[ - i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i] - ] - - if segment_input_slice.shape[-1] < num_segment_frames: - # pad to 3000 if necessary - segment_input_slice = torch.nn.functional.pad( - segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1]) - ) - - segment_input.append(segment_input_slice) - - segment_input = torch.cat(segment_input, dim=0) - - # 6.6 Batch generate current chunk - seek_outputs = super().generate( - segment_input, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - ) - - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) - seek_outputs["token_timestamps"] = self._extract_token_timestamps( - seek_outputs, generation_config.alignment_heads, num_frames=num_frames - ) - - if return_dict_in_generate: - seek_sequences = 
seek_outputs["sequences"] - seek_outputs = [ - {k: v[i] for k, v in seek_outputs.items()} - for i in range(next(iter(seek_outputs.values())).size(0)) - ] - else: - seek_sequences = seek_outputs - - # 6.7 Loop over each decoded audio individually as each decoding can be of a different length - for i, seek_sequence in enumerate(seek_sequences): - prev_i = cur_to_prev_index_map[i] - - # make sure we cut a predicted EOS token if we are not finished with the generation yet - is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] - if is_not_final and seek_sequence[-1] == self.generation_config.eos_token_id: - seek_sequence = seek_sequence[:-1] - - # remove all padding tokens - if seek_sequence[-1] == self.generation_config.pad_token_id: - num_paddings = (seek_sequence == self.generation_config.pad_token_id).sum() - seek_sequence = seek_sequence[:-num_paddings] - - segments, segment_offset = self._retrieve_segment( - seek_sequence=seek_sequence, - seek_outputs=seek_outputs, - time_offset=time_offset, - timestamp_begin=timestamp_begin, - seek_num_frames=seek_num_frames, - cur_bsz=cur_bsz, - time_precision=time_precision, - input_stride=input_stride, - prev_idx=prev_i, - idx=i, - ) - - current_segments[prev_i] += segments - seek[prev_i] += segment_offset - - # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted - # output tokens from the list of dicts. 
If we use batch size > 1, we make sure to pad the output - sequences = [] - max_total_length = 0 - for current_segment_list in current_segments: - sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1)) - max_total_length = max(max_total_length, len(sequences[-1])) - - for i in range(batch_size): - sequences[i] = torch.nn.functional.pad( - sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id - ) + class DummyWhisperModel: + def __init__(self): + self.encoder = self.Encoder() - sequences = torch.stack(sequences, dim=0) + class Encoder: + def __init__(self): + self.conv1 = self.Conv(stride=(1,)) + self.conv2 = self.Conv(stride=(2,)) - # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. - if return_segments: - return {"sequences": sequences, "segments": current_segments} + class Conv: + def __init__(self, stride): + self.stride = stride - return sequences + # a dummy model attribute that's used in the generate method to compute the input stride + # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] + model = DummyWhisperModel() @add_end_docstrings(ONNX_MODEL_END_DOCSTRING) diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index ec27fe8db4b..36913f652a8 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -555,6 +555,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int None, None, ) + return super().generate(input_name, framework=framework, int_dtype=int_dtype) @@ -610,7 +611,7 @@ class DummySeq2SeqPastKeyValuesGenerator(DummyInputGenerator): Generates dummy past_key_values inputs for seq2seq architectures. 
""" - SUPPORTED_INPUT_NAMES = ("past_key_values",) + SUPPORTED_INPUT_NAMES = ("past_key_values", "cache_position") def __init__( self, @@ -658,27 +659,38 @@ def __init__( self.decoder_num_layers = self.normalized_config.decoder_num_layers def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): - encoder_shape = ( - self.batch_size, - self.encoder_num_attention_heads, - self.encoder_sequence_length, - self.encoder_hidden_size // self.encoder_num_attention_heads, - ) - decoder_shape = ( - self.batch_size, - self.decoder_num_attention_heads, - self.sequence_length, - self.decoder_hidden_size // self.decoder_num_attention_heads, - ) - return [ - ( - self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype), - self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype), - self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype), - self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype), + if input_name == "past_key_values": + encoder_shape = ( + self.batch_size, + self.encoder_num_attention_heads, + self.encoder_sequence_length, + self.encoder_hidden_size // self.encoder_num_attention_heads, ) - for _ in range(self.decoder_num_layers) - ] + decoder_shape = ( + self.batch_size, + self.decoder_num_attention_heads, + self.sequence_length, + self.decoder_hidden_size // self.decoder_num_attention_heads, + ) + return [ + ( + self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.decoder_num_layers) + ] + + elif input_name == "cache_position": + return self.random_int_tensor( + shape=[1], + max_value=self.sequence_length, + 
framework=framework, + dtype=int_dtype, + ) + + raise ValueError(f"Unsupported input name {input_name}") # TODO: should it just be merged to DummyTextInputGenerator? diff --git a/setup.py b/setup.py index ce7e537330e..3043eeee6c8 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ REQUIRED_PKGS = [ "coloredlogs", "sympy", - "transformers[sentencepiece]>=4.29.0,<4.43.0", + "transformers[sentencepiece]>=4.29.0,<4.44.0", "torch>=1.11", "packaging", "numpy<2.0", # transformers requires numpy<2.0 https://github.com/huggingface/transformers/pull/31569 diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py index 6e7ff71ddd9..e9e2edd9790 100644 --- a/tests/bettertransformer/testing_utils.py +++ b/tests/bettertransformer/testing_utils.py @@ -27,7 +27,7 @@ MODELS_DICT = { "albert": "hf-internal-testing/tiny-random-AlbertModel", - "bark": "ylacombe/bark-small", # TODO: put a smaller model, this one is 1.7GB... + "bark": "ylacombe/bark-small", "bart": "hf-internal-testing/tiny-random-bart", "bert": "hf-internal-testing/tiny-random-BertModel", "bert-generation": "ybelkada/random-tiny-BertGenerationModel", @@ -359,7 +359,8 @@ def _test_save_load_invertible(self, model_id, keep_original_model=True): for name, param in bt_model.named_parameters(): self.assertFalse(param.device.type == "meta", f"Parameter {name} is on the meta device.") - bt_model.save_pretrained(tmpdirname) + # saving a normal transformers bark model fails because of shared tensors + bt_model.save_pretrained(tmpdirname, safe_serialization=hf_model.config.model_type != "bark") bt_model_from_load = AutoModel.from_pretrained(tmpdirname) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 473ab4cf3b8..4b44acb38ab 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -4339,25 +4339,25 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach set_seed(SEED) 
transformers_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id) - processor = get_preprocessor(model_id) + processor = get_preprocessor(model_id) data = self._generate_random_audio_data() - - features = processor.feature_extractor(data, return_tensors="pt") + features = { + "np": processor.feature_extractor(data, return_tensors="np"), + "pt": processor.feature_extractor(data, return_tensors="pt"), + } decoder_start_token_id = transformers_model.config.decoder_start_token_id - decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id} + decoder_inputs = { + "np": {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id}, + "pt": {"decoder_input_ids": torch.ones((1, 1), dtype=torch.int64) * decoder_start_token_id}, + } with torch.no_grad(): - transformers_outputs = transformers_model(**features, **decoder_inputs) + transformers_outputs = transformers_model(**features["pt"], **decoder_inputs["pt"]) for input_type in ["pt", "np"]: - features = processor.feature_extractor(data, return_tensors=input_type) - - if input_type == "np": - decoder_inputs = {"decoder_input_ids": np.ones((1, 1), dtype=np.int64) * decoder_start_token_id} - - onnx_outputs = onnx_model(**features, **decoder_inputs) + onnx_outputs = onnx_model(**features[input_type], **decoder_inputs[input_type]) self.assertTrue("logits" in onnx_outputs) self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type]) @@ -4365,6 +4365,27 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach # Compare tensor outputs self.assertTrue(torch.allclose(torch.Tensor(onnx_outputs.logits), transformers_outputs.logits, atol=1e-4)) + new_tokens = 20 # because tiny random speech to text model has a max_position_embeddings of 20 + + with torch.no_grad(): + transformers_outputs = transformers_model.generate( + **features["pt"], + max_new_tokens=new_tokens, + min_new_tokens=new_tokens, + do_sample=False, + 
num_beams=1, + ) + + onnx_outputs = onnx_model.generate( + **features["pt"], + max_new_tokens=new_tokens, + min_new_tokens=new_tokens, + do_sample=False, + num_beams=1, + ) + + self.assertTrue(torch.equal(onnx_outputs, transformers_outputs)) + gc.collect() @parameterized.expand(grid_parameters(FULL_GRID)) @@ -4473,7 +4494,7 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): generation_length = self.GENERATION_LENGTH self.GENERATION_LENGTH = 10 - _ = model_with_pkv.generate(**features) # warpup + _ = model_with_pkv.generate(**features) # warmup with Timer() as with_pkv_timer: outputs_model_with_pkv = model_with_pkv.generate( **features, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 @@ -4482,15 +4503,22 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): model_without_pkv = ORTModelForSpeechSeq2Seq.from_pretrained( self.onnx_model_dirs[model_arch + "_False"], use_cache=False ) - _ = model_without_pkv.generate(**features) # warpup + _ = model_without_pkv.generate(**features) # warmup with Timer() as without_pkv_timer: outputs_model_without_pkv = model_without_pkv.generate( **features, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 ) self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) - self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + 1) - self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + 1) + self.assertEqual( + outputs_model_with_pkv.shape[1], + self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1, + ) + self.assertEqual( + outputs_model_without_pkv.shape[1], + self.GENERATION_LENGTH + 2 if model_arch == "whisper" else self.GENERATION_LENGTH + 1, + ) + self.GENERATION_LENGTH = generation_length if os.environ.get("TEST_LEVEL", 0) == "1": self.assertTrue( @@ -4547,7 +4575,6 @@ def 
test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode ) self.GENERATION_LENGTH = generation_length - self.assertTrue(torch.equal(outputs_model_merged, outputs_model_not_merged)) @parameterized.expand(