diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 8cd94194ffe..5e2118b42c8 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -266,7 +266,9 @@ def variant(self, value: str): if value == "default" and hasattr(self, "DEFAULT_VARIANT"): value = self.DEFAULT_VARIANT if value not in self.VARIANTS: - raise ValueError(f"The variant {value} is not supported for the ONNX config {self.__class__.__name__}.") + raise ValueError( + f"The variant {value} is not supported for the ONNX config {self.__class__.__name__}. Available variants {self.VARIANTS.keys()}" + ) self._variant = value def fix_dynamic_axes( @@ -645,7 +647,8 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): and "attention_mask" in dummy_inputs ): # Obtain the past sequence length from the value instead of the key (Bloom). - past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[-2] + input_name = "inputs_embeds" if "inputs_embeds" in dummy_inputs else "input_ids" + past_present_length = dummy_inputs[input_name].shape[1] + dummy_inputs["past_key_values"][0][1].shape[-2] dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( dummy_inputs["attention_mask"], diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index e23716d4b74..bab9a7b0645 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -15,7 +15,7 @@ """Model specific ONNX configurations.""" import random from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union from packaging import version from transformers.utils import is_tf_available @@ -72,6 +72,7 @@ from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME from .model_patcher import ( FalconModelPatcher, + LlavaModelPatcher, MusicgenModelPatcher, SAMModelPatcher, SentenceTransformersCLIPPatcher, @@ -274,7 +275,7 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator - NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) class Qwen2OnnxConfig(LlamaOnnxConfig): @@ -976,6 +977,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): return dummy_inputs +class CLIPVisionOnnxConfig(ViTOnnxConfig): + pass + + class UNetOnnxConfig(VisionOnnxConfig): ATOL_FOR_VALIDATION = 1e-3 # The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu @@ -2239,3 +2244,260 @@ def overwrite_shape_and_generate_input( class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig + + +class LlavaOnnxConfig(OnnxConfigWithPast): + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,) + NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig + DEFAULT_ONNX_OPSET = 14 + + VARIANTS = { + "default": "The export follows the Transformers implementation of forward in LlavaModelForConditionalGeneration, with the following components exported:\n\t - " + "model.onnx: corresponds to the vision encoder + projection + decoder in a single file without past key value 
support in https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llava/modeling_llava.py#L360-L519.\n\t - "
+        "decoder_model.onnx: corresponds to the decoder part, with past_key_values input, in https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llava/modeling_llava.py#L449-L489.",
+        "optimized": "The export follows a memory-optimized implementation of the Transformers forward. This is the recommended export, as the decoder is exported only once. It has the following components exported:\n\t - "
+        "encoder_model.onnx: corresponds to the vision encoder + projection + decoder in https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llava/modeling_llava.py#L421-L445.\n\t - "
+        "decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llava/modeling_llava.py#L480-L489.\n\t - "
+        "decoder_input_processor.onnx: corresponds to decoder input generation when past_key_values is provided in https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llava/modeling_llava.py#L421-L478.",
+    }
+
+    DEFAULT_VARIANT = "optimized"
+
+    def __init__(
+        self,
+        config: "PretrainedConfig",
+        task: str = "image-to-text-with-past",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        use_past: bool = False,
+        use_past_in_inputs: bool = False,
+        behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
+        preprocessors: Optional[List[Any]] = None,
+        variant: str = "default",
+        legacy: bool = False,
+        decoder_input_processor_export: Optional[bool] = None,
+    ):
+        super().__init__(config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, preprocessors, legacy)
+
+        if legacy:
+            raise ValueError("LlavaOnnxConfig is not supported in legacy mode.")
+
+        self._behavior = behavior
+        self.variant = variant
+        self.decoder_input_processor_export = decoder_input_processor_export
+
+        if variant == "default" and behavior is ConfigBehavior.ENCODER:
+            raise ValueError(f"Llava does not support encoder-only export for variant {variant}.")
+
+        # Local import to avoid circular imports.
+        from optimum.exporters.tasks import TasksManager
+
+        # Set up the encoder ONNX config.
+        encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
+            exporter="onnx",
+            task="feature-extraction",
+            model_type=config.vision_config.model_type.replace("_", "-"),
+            library_name="transformers",
+        )
+        self._encoder_onnx_config = encoder_onnx_config_constructor(
+            config.vision_config, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
+        )
+
+        self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = self._encoder_onnx_config._normalized_config
+
+        # Set up the decoder ONNX config.
+ task = "text-generation-with-past" if use_past else "text-generation" + decoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor( + exporter="onnx", + task="feature-extraction", + model_type=config.text_config.model_type.replace("_", "-"), + library_name="transformers", + ) + self._decoder_onnx_config = decoder_onnx_config_constructor( + config.text_config, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + ) + + self.is_decoder_with_past = issubclass(decoder_onnx_config_constructor.func, OnnxConfigWithPast) + if not self.is_decoder_with_past: + raise ValueError("LLava does not support decoder without past_key_values input.") + + self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config + + self.DUMMY_INPUT_GENERATOR_CLASSES += self._decoder_onnx_config.DUMMY_INPUT_GENERATOR_CLASSES + self.DUMMY_PKV_GENERATOR_CLASS = self._decoder_onnx_config.DUMMY_PKV_GENERATOR_CLASS + + def with_behavior( + self, + behavior: Union[str, ConfigBehavior], + use_past: bool = False, + use_past_in_inputs: bool = False, + decoder_input_processor_export: Optional[bool] = None, + ) -> OnnxConfigWithPast: + if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior): + behavior = ConfigBehavior(behavior) + + onnx_config = self.__class__( + self._config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + behavior=behavior, + preprocessors=self._preprocessors, + variant=self.variant, + legacy=self.legacy, + decoder_input_processor_export=decoder_input_processor_export, + ) + return onnx_config + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = { + "pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + } + + if self.variant == "transformers": + if self._behavior is ConfigBehavior.DECODER: + common_inputs["input_ids"] = {0: "batch_size"} + + if self.use_past_in_inputs: + self.add_past_key_values(common_inputs, direction="inputs") + + elif self.variant == "optimized": + if self._behavior is ConfigBehavior.DECODER: + common_inputs = { + "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length", 2: "hidden_size"}, + "attention_mask": {0: "batch_size", 1: "decoder_sequence_length+past_sequence_length"}, + "position_ids": {0: "batch_size", 1: "decoder_sequence_length"}, + } + + if self.use_past_in_inputs: + self.add_past_key_values(common_inputs, direction="inputs") + + if self.decoder_input_processor_export is True: + common_inputs.pop("inputs_embeds") + common_inputs.pop("position_ids") + common_inputs["input_ids"] = {0: "batch_size"} + common_inputs["attention_mask"] = common_inputs.pop("attention_mask") + + pkv_names = [key for key in common_inputs.keys() if key.startswith("past_key_values")][1:] + for pkv_name in pkv_names: + common_inputs.pop(pkv_name) + + return common_inputs + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self.variant == "transformers": + outputs = { + "logits": {0: "batch_size", 1: "decoder_sequence_length", 2: "vocab_size"}, + } + if self.use_past: + self.add_past_key_values(outputs, direction="outputs") + elif self.variant == "optimized": + if self._behavior is ConfigBehavior.ENCODER: + outputs = { + "inputs_embeds": {0: "batch_size", 
1: "decoder_sequence_length", 2: "hidden_size"}, + "decoder_attention_mask": {0: "batch_size", 1: "decoder_sequence_length"}, + "position_ids": {0: "batch_size", 1: "decoder_sequence_length"}, + } + elif self._behavior is ConfigBehavior.DECODER and self.decoder_input_processor_export is True: + outputs = { + "inputs_embeds": {0: "batch_size", 2: "hidden_size"}, + "decoder_attention_mask": {0: "batch_size", 1: "past_decoder_sequence_length + 1"}, + "position_ids": {0: "batch_size"}, + } + elif self._behavior is ConfigBehavior.DECODER: + outputs = { + "logits": {0: "batch_size", 1: "decoder_sequence_length", 2: "vocab_size"}, + } + if self.use_past: + self.add_past_key_values(outputs, direction="outputs") + + return outputs + + def overwrite_shape_and_generate_input( + self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict + ): + if self.use_past and self.use_past_in_inputs and input_name == "input_ids": + if self.variant == "default" or ( + self.variant == "optimized" and self.decoder_input_processor_export is True + ): + sequence_length = dummy_input_gen.sequence_length + # Use a sequence length of 1 when the KV cache is already populated. + dummy_input_gen.sequence_length = 1 + dummy_input = dummy_input_gen.generate( + input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype + ) + dummy_input_gen.sequence_length = sequence_length + else: + dummy_input = dummy_input_gen.generate( + input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype + ) + + return dummy_input + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + if self.is_decoder_with_past: + return self._decoder_onnx_config.add_past_key_values(inputs_or_outputs, direction) + + def flatten_past_key_values(self, flattened_output, name, idx, t): + if self.is_decoder_with_past: + return self._decoder_onnx_config.flatten_past_key_values(flattened_output, name, idx, t) + + def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]: + return self._decoder_onnx_config.flatten_output_collection_property(name, field) + + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + self.PAD_ATTENTION_MASK_TO_PAST = self._decoder_onnx_config.PAD_ATTENTION_MASK_TO_PAST + + dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) + + if "pixel_values" in dummy_inputs: + input_ids = dummy_inputs["input_ids"] + mask = input_ids == self._config.image_token_index + input_ids[mask] = self._config.pad_token_id + + if self._behavior is ConfigBehavior.MONOLITH or self._behavior is ConfigBehavior.ENCODER: + input_ids[:, 1] = self._config.image_token_index + + dummy_inputs["input_ids"] = input_ids + + if ( + self.variant == "optimized" + and self._behavior is ConfigBehavior.DECODER + and self.decoder_input_processor_export is True + ): + dummy_inputs["past_key_values"] = dummy_inputs["past_key_values"][0][0][:, :, :, 0] + + return dummy_inputs + + def generate_dummy_inputs_for_validation( + self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None + ) -> Dict[str, Any]: + dummy_inputs = super().generate_dummy_inputs_for_validation(reference_model_inputs, onnx_input_names) + + if self.variant == "default" and self._behavior is ConfigBehavior.DECODER: + dummy_inputs.pop("pixel_values") + + if ( + self.variant == "optimized" + and self._behavior is ConfigBehavior.DECODER + and self.decoder_input_processor_export is 
True + ): + dummy_inputs["past_key_values.0.key"] = dummy_inputs.pop("past_key_values") + + return dummy_inputs + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return LlavaModelPatcher(self, model, model_kwargs=model_kwargs) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 0a105343546..348d746d962 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -396,6 +396,204 @@ def __init__( self.original_make_causal = AttentionMaskConverter._make_causal_mask +def patched_merge_input_ids_with_image_features( + self, image_features, inputs_embeds, input_ids, attention_mask, labels +): + num_images, num_image_patches, embed_dim = image_features.shape + batch_size, sequence_length = input_ids.shape + left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == self.config.image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + + # Compute the maximum embed dimension + max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length + batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index) + + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1 + nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1] + if left_padding: + new_token_positions += nb_image_pad[:, None] # offset for left padding + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + if labels is not None: + final_labels = torch.full( + (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + + # 4. Fill the embeddings based on the mask. 
If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + if labels is not None: + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] + + # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling + image_to_overwrite = torch.ones((batch_size, max_embed_dim), dtype=torch.bool, device=inputs_embeds.device) + image_to_overwrite[batch_indices, text_to_overwrite] = False + + # ModelPatcher Fix: Exporting the operator 'aten::__iand_' not supported AND cumsum inut should be INT not BOOL + # image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + image_to_overwrite_int = image_to_overwrite.to(final_attention_mask.dtype) + mask = image_to_overwrite_int.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + image_to_overwrite *= mask + + if image_to_overwrite.sum() != image_features.shape[:-1].numel(): + raise ValueError( + f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation." + ) + + final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) + + # ModelPatcher Fix: Exporting the operator 'aten::__ior_' not supported + # final_attention_mask2 = final_attention_mask | image_to_overwrite + final_attention_mask = torch.max(final_attention_mask, image_to_overwrite.to(final_attention_mask.dtype)) + + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + + # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens. 
+ batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id) + indices_to_mask = new_token_positions[batch_indices, pad_indices] + + final_embedding[batch_indices, indices_to_mask] = 0 + + if labels is None: + final_labels = None + + final_attention_mask = final_attention_mask.to(position_ids.dtype) + + return final_embedding, final_attention_mask, final_labels, position_ids + + +class LlavaModelPatcher(ModelPatcher): + def __enter__(self): + super().__enter__() + self.patch_ops() + setattr( + self._model, + "_merge_input_ids_with_image_features", + types.MethodType(patched_merge_input_ids_with_image_features, self._model), + ) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self.restore_ops() + setattr(self._model, self.orig_forward_name, self.orig_forward) + setattr( + self._model, + "_merge_input_ids_with_image_features", + types.MethodType(self.original_merge_input_ids_with_image_features, self._model), + ) + + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + self.original_merge_input_ids_with_image_features = self._model._merge_input_ids_with_image_features + + def patched_forward( + input_ids=None, + pixel_values=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + inputs_embeds=None, + ): + vision_feature_layer = model.config.vision_feature_layer + vision_feature_select_strategy = model.config.vision_feature_select_strategy + + if config._behavior == "encoder": + inputs_embeds = model.get_input_embeddings()(input_ids) + + image_outputs = model.vision_tower(pixel_values, output_hidden_states=True) + selected_image_feature = image_outputs.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + else: + raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}") + + image_features = model.multi_modal_projector(selected_image_feature) + + inputs_embeds, attention_mask, labels, position_ids = model._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, None + ) + + result = { + "inputs_embeds": inputs_embeds, + "decoder_attention_mask": attention_mask, + "position_ids": position_ids, + } + elif config._behavior == "decoder" and config.decoder_input_processor_export is True: + inputs_embeds = model.get_input_embeddings()(input_ids) + + first_layer_past_key_value = past_key_values + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + 
new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -1:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + + result = { + "inputs_embeds": inputs_embeds, + "decoder_attention_mask": attention_mask, + "position_ids": position_ids, + } + + return result + + if config.variant == "optimized" and ( + config._behavior != "decoder" or config.decoder_input_processor_export is True + ): + self.patched_forward = patched_forward + + def falcon_build_alibi_tensor_patched( attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype ) -> torch.Tensor: diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 608b3df0d7c..945c136a948 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -446,6 +446,10 @@ class TasksManager: "zero-shot-image-classification", onnx="CLIPOnnxConfig", ), + "clip-vision-model": supported_tasks_mapping( + "feature-extraction", + onnx="CLIPVisionOnnxConfig", + ), "codegen": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", @@ -870,6 +874,11 @@ class TasksManager: "text-classification", onnx="LlamaOnnxConfig", ), + "llava": supported_tasks_mapping( + "image-to-text", + "image-to-text-with-past", + onnx="LlavaOnnxConfig", + ), "pegasus": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 74d2d983850..f8bfff56f0b 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -482,6 +482,37 @@ def get_speecht5_models_for_export( return models_for_export +def get_llava_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel"], config: "ExportConfig"): + models_for_export = {} + + if config.variant == "default": + monolith_export_config = config.with_behavior("monolith", use_past=config.use_past, use_past_in_inputs=False) + models_for_export["model"] = (model, monolith_export_config) + + if config.use_past: + decoder_export_config_with_past = config.with_behavior("decoder", use_past=True, use_past_in_inputs=True) + models_for_export[DECODER_NAME] = (model, decoder_export_config_with_past) + elif config.variant == "optimized": + encoder_export_config = config.with_behavior("encoder") + models_for_export[ENCODER_NAME] = (model, encoder_export_config) + + decoder_export_config = config.with_behavior( + "decoder", use_past=config.use_past, use_past_in_inputs=config.use_past + ) + models_for_export[DECODER_NAME] = (model, decoder_export_config) + + if config.use_past: + decoder_preprocess_export_config = config.with_behavior( + "decoder", + use_past=config.use_past, + use_past_in_inputs=config.use_past, + decoder_input_processor_export=True, + ) + models_for_export["decoder_input_processor_model"] = (model, decoder_preprocess_export_config) + + return models_for_export + + def override_diffusers_2_0_attn_processors(model): for _, submodule in model.named_modules(): if isinstance(submodule, Attention): @@ -557,6 +588,8 @@ def _get_submodels_and_export_configs( models_and_export_configs = get_sam_models_for_export(model, export_config) elif model.config.model_type == "speecht5": models_and_export_configs = get_speecht5_models_for_export(model, export_config, model_kwargs) + elif model.config.model_type == "llava": + models_and_export_configs = 
get_llava_models_for_export(model, export_config)
         elif model.config.model_type == "musicgen":
             models_and_export_configs = get_musicgen_models_for_export(model, export_config)
         else:
diff --git a/optimum/onnxruntime/runs/__init__.py b/optimum/onnxruntime/runs/__init__.py
index 1d982949344..d21db2a4aca 100644
--- a/optimum/onnxruntime/runs/__init__.py
+++ b/optimum/onnxruntime/runs/__init__.py
@@ -110,9 +110,9 @@ def __init__(self, run_config):
         model_class = FeaturesManager.get_model_class_for_feature(get_autoclass_name(self.task))
         self.torch_model = model_class.from_pretrained(run_config["model_name_or_path"])
 
-        self.return_body[
-            "model_type"
-        ] = self.torch_model.config.model_type  # return_body is initialized in parent class
+        self.return_body["model_type"] = (
+            self.torch_model.config.model_type
+        )  # return_body is initialized in parent class
 
     def _launch_time(self, trial):
         batch_size = trial.suggest_categorical("batch_size", self.batch_sizes)
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index ec27fe8db4b..9f140eb6f87 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -381,6 +381,7 @@ class DummyTextInputGenerator(DummyInputGenerator):
 
     SUPPORTED_INPUT_NAMES = (
         "input_ids",
+        "inputs_embeds",
         "attention_mask",
         "encoder_attention_mask",
         "token_type_ids",
@@ -401,6 +402,7 @@ def __init__(
         **kwargs,
     ):
         self.task = task
+        self.hidden_size = normalized_config.hidden_size
 
         if isinstance(normalized_config, NormalizedEncoderDecoderConfig):
             self.vocab_size = normalized_config.vocab_size
@@ -435,6 +437,15 @@ def generate(
         min_value = 0
         max_value = 2 if input_name != "input_ids" else self.vocab_size
         shape = [self.batch_size, self.sequence_length]
+
+        if input_name == "inputs_embeds":
+            return self.random_float_tensor(
+                shape=[self.batch_size, self.sequence_length, self.hidden_size],
+                min_value=0,
+                max_value=1,
+                framework=framework,
+                dtype=float_dtype,
+            )
         if self.task == "multiple-choice":
             shape = [self.batch_size, self.num_choices, self.sequence_length]
         if "mask" in input_name:
diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py
index 682f70e3ca3..f5c8087af77 100644
--- a/optimum/utils/normalized_config.py
+++ b/optimum/utils/normalized_config.py
@@ -124,14 +124,15 @@ class NormalizedEncoderDecoderConfig(NormalizedConfig):
     DECODER_NORMALIZED_CONFIG_CLASS = None
 
     def __getattr__(self, attr_name):
-        if self.ENCODER_NORMALIZED_CONFIG_CLASS is not None and attr_name.upper() in dir(
-            self.ENCODER_NORMALIZED_CONFIG_CLASS
-        ):
-            return self.ENCODER_NORMALIZED_CONFIG_CLASS.__getattr__(attr_name)
+        # Giving preference to decoder config
         if self.DECODER_NORMALIZED_CONFIG_CLASS is not None and attr_name.upper() in dir(
             self.DECODER_NORMALIZED_CONFIG_CLASS
         ):
             return self.DECODER_NORMALIZED_CONFIG_CLASS.__getattr__(attr_name)
+        if self.ENCODER_NORMALIZED_CONFIG_CLASS is not None and attr_name.upper() in dir(
+            self.ENCODER_NORMALIZED_CONFIG_CLASS
+        ):
+            return self.ENCODER_NORMALIZED_CONFIG_CLASS.__getattr__(attr_name)
 
         return super().__getattr__(attr_name)