Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LLava ONNX export #1790

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,9 @@ def variant(self, value: str):
if value == "default" and hasattr(self, "DEFAULT_VARIANT"):
value = self.DEFAULT_VARIANT
if value not in self.VARIANTS:
raise ValueError(f"The variant {value} is not supported for the ONNX config {self.__class__.__name__}.")
raise ValueError(
f"The variant {value} is not supported for the ONNX config {self.__class__.__name__}. Available variants {self.VARIANTS.keys()}"
)
self._variant = value

def fix_dynamic_axes(
Expand Down Expand Up @@ -645,7 +647,8 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
and "attention_mask" in dummy_inputs
):
# Obtain the past sequence length from the value instead of the key (Bloom).
past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[-2]
input_name = "inputs_embeds" if "inputs_embeds" in dummy_inputs else "input_ids"
past_present_length = dummy_inputs[input_name].shape[1] + dummy_inputs["past_key_values"][0][1].shape[-2]

dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"],
Expand Down
266 changes: 264 additions & 2 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Model specific ONNX configurations."""
import random
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union

from packaging import version
from transformers.utils import is_tf_available
Expand Down Expand Up @@ -72,6 +72,7 @@
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import (
FalconModelPatcher,
LlavaModelPatcher,
MusicgenModelPatcher,
SAMModelPatcher,
SentenceTransformersCLIPPatcher,
Expand Down Expand Up @@ -274,7 +275,7 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):

DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)


class Qwen2OnnxConfig(LlamaOnnxConfig):
Expand Down Expand Up @@ -976,6 +977,10 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
return dummy_inputs


class CLIPVisionOnnxConfig(ViTOnnxConfig):
    """ONNX config for the CLIP vision encoder; reuses the ViT configuration unchanged."""

    pass


class UNetOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
# The ONNX export of a CLIPText architecture, an other Stable Diffusion component, needs the Trilu
Expand Down Expand Up @@ -2239,3 +2244,260 @@ def overwrite_shape_and_generate_input(

class EncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
    """Generic encoder-decoder ONNX config; only pins the normalized config class."""

    NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig


class LlavaOnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for LLaVA (vision tower + multimodal projector + text decoder).

    Two export layouts are supported (see ``VARIANTS``):

    - ``"default"``: mirrors the Transformers forward pass of
      ``LlavaForConditionalGeneration`` (monolithic model plus a decoder-with-past).
    - ``"optimized"``: splits the model into an encoder, a decoder, and a
      decoder-input-processor so the decoder weights are exported only once.
    """

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)
    NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
    DEFAULT_ONNX_OPSET = 14

    VARIANTS = {
        "default": "The export follows the Transformers implementation of forward in LlavaModelForConditionalGeneration, with the following components exported:\n\t - "
        "model.onnx: corresponds to the vision encoder + projection + decoder in a single file without past key value support in https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llava/modeling_llava.py#L360-L519.\n\t - "
        "decoder_model.onnx: corresponds to the decoder part in with past_key_values input https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llava/modeling_llava.py#L449-L489.",
        "optimized": "The export follows the memory optimized implementation of Transformers forward. This is a recommended export as decoder is exported only once`. It has the following components exported:\n\t - "
        "encoder_model.onnx: corresponds to the vision encoder + projection + decoder in https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llava/modeling_llava.py#L421-L445.\n\t - "
        "decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llava/modeling_llava.py#L480-L489.\n\t - "
        "decoder_input_processor.onnx: corresponds to decoder input generation when past_key_values is provided in https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/llava/modeling_llava.py#L421-L478.",
    }

    DEFAULT_VARIANT = "optimized"

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "image-to-text-with-past",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        use_past: bool = False,
        use_past_in_inputs: bool = False,
        behavior: ConfigBehavior = ConfigBehavior.MONOLITH,
        preprocessors: Optional[List[Any]] = None,
        variant: str = "default",
        legacy: bool = False,
        decoder_input_processor_export: Optional[bool] = None,
    ):
        """Build the composite config from the LLaVA ``vision_config`` and ``text_config``.

        Args:
            config: The LLaVA ``PretrainedConfig`` (must expose ``vision_config`` and ``text_config``).
            behavior: Which sub-model this config describes (MONOLITH / ENCODER / DECODER).
            variant: One of ``VARIANTS`` ("default" or "optimized").
            decoder_input_processor_export: In the "optimized" variant, export the
                decoder-input-processor component instead of the decoder itself.

        Raises:
            ValueError: If ``legacy`` is requested, if the variant/behavior combination
                is unsupported, or if the decoder config does not support past key values.
        """
        super().__init__(config, task, int_dtype, float_dtype, use_past, use_past_in_inputs, preprocessors, legacy)

        if legacy:
            # Fix: the guard rejects legacy mode, so the message must say "not supported"
            # (the original message stated the opposite).
            raise ValueError("LLavaOnnxConfig is not supported in legacy mode.")

        self._behavior = behavior
        self.variant = variant
        self.decoder_input_processor_export = decoder_input_processor_export

        # The "default" variant has no standalone encoder export (see VARIANTS).
        if variant == "default" and behavior is ConfigBehavior.ENCODER:
            raise ValueError(f"LLava does not support encoder-only export for variant {variant}.")

        # Local import to avoid circular imports.
        from optimum.exporters.tasks import TasksManager

        # Set up the encoder (vision tower) ONNX config.
        encoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
            exporter="onnx",
            task="feature-extraction",
            model_type=config.vision_config.model_type.replace("_", "-"),
            library_name="transformers",
        )
        self._encoder_onnx_config = encoder_onnx_config_constructor(
            config.vision_config, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors
        )

        self._normalized_config.ENCODER_NORMALIZED_CONFIG_CLASS = self._encoder_onnx_config._normalized_config

        # Set up the decoder (language model) ONNX config.
        # NOTE(review): the decoder config is looked up under the "feature-extraction"
        # task; the past-key-values handling is driven by use_past/use_past_in_inputs below.
        decoder_onnx_config_constructor = TasksManager.get_exporter_config_constructor(
            exporter="onnx",
            task="feature-extraction",
            model_type=config.text_config.model_type.replace("_", "-"),
            library_name="transformers",
        )
        self._decoder_onnx_config = decoder_onnx_config_constructor(
            config.text_config,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
            use_past=use_past,
            use_past_in_inputs=use_past_in_inputs,
        )

        self.is_decoder_with_past = issubclass(decoder_onnx_config_constructor.func, OnnxConfigWithPast)
        if not self.is_decoder_with_past:
            raise ValueError("LLava does not support decoder without past_key_values input.")

        self._normalized_config.DECODER_NORMALIZED_CONFIG_CLASS = self._decoder_onnx_config._normalized_config

        # Merge the decoder's dummy-input machinery into this composite config.
        self.DUMMY_INPUT_GENERATOR_CLASSES += self._decoder_onnx_config.DUMMY_INPUT_GENERATOR_CLASSES
        self.DUMMY_PKV_GENERATOR_CLASS = self._decoder_onnx_config.DUMMY_PKV_GENERATOR_CLASS

    def with_behavior(
        self,
        behavior: Union[str, ConfigBehavior],
        use_past: bool = False,
        use_past_in_inputs: bool = False,
        decoder_input_processor_export: Optional[bool] = None,
    ) -> OnnxConfigWithPast:
        """Return a copy of this config targeting the given sub-model behavior."""
        if isinstance(behavior, str) and not isinstance(behavior, ConfigBehavior):
            behavior = ConfigBehavior(behavior)

        onnx_config = self.__class__(
            self._config,
            task=self.task,
            int_dtype=self.int_dtype,
            float_dtype=self.float_dtype,
            use_past=use_past,
            use_past_in_inputs=use_past_in_inputs,
            behavior=behavior,
            preprocessors=self._preprocessors,
            variant=self.variant,
            legacy=self.legacy,
            decoder_input_processor_export=decoder_input_processor_export,
        )
        return onnx_config

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """ONNX input names and dynamic axes for the current variant/behavior."""
        common_inputs = {
            "pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
        }

        # Fix: this branch previously checked the nonexistent "transformers" variant
        # (VARIANTS only declares "default" and "optimized", and the variant setter
        # rejects anything else), so the "default" decoder export never received
        # past_key_values inputs.
        if self.variant == "default":
            if self._behavior is ConfigBehavior.DECODER:
                common_inputs["input_ids"] = {0: "batch_size"}

                if self.use_past_in_inputs:
                    self.add_past_key_values(common_inputs, direction="inputs")

        elif self.variant == "optimized":
            if self._behavior is ConfigBehavior.DECODER:
                common_inputs = {
                    "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length", 2: "hidden_size"},
                    "attention_mask": {0: "batch_size", 1: "decoder_sequence_length+past_sequence_length"},
                    "position_ids": {0: "batch_size", 1: "decoder_sequence_length"},
                }

                if self.use_past_in_inputs:
                    self.add_past_key_values(common_inputs, direction="inputs")

                if self.decoder_input_processor_export is True:
                    # The input processor consumes raw input_ids; re-inserting the
                    # attention mask keeps it last in the input ordering.
                    common_inputs.pop("inputs_embeds")
                    common_inputs.pop("position_ids")
                    common_inputs["input_ids"] = {0: "batch_size"}
                    common_inputs["attention_mask"] = common_inputs.pop("attention_mask")

                    # Keep only the first past_key_values tensor: it is used solely
                    # to read the past sequence length.
                    pkv_names = [key for key in common_inputs.keys() if key.startswith("past_key_values")][1:]
                    for pkv_name in pkv_names:
                        common_inputs.pop(pkv_name)

        return common_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """ONNX output names and dynamic axes for the current variant/behavior."""
        # Fix: same dead "transformers" check as in `inputs` — with variant "default"
        # this method previously fell through both branches and raised
        # UnboundLocalError on `return outputs`.
        if self.variant == "default":
            outputs = {
                "logits": {0: "batch_size", 1: "decoder_sequence_length", 2: "vocab_size"},
            }
            if self.use_past:
                self.add_past_key_values(outputs, direction="outputs")
        elif self.variant == "optimized":
            # NOTE(review): MONOLITH behavior is not expected here for the
            # "optimized" variant; only ENCODER/DECODER are handled.
            if self._behavior is ConfigBehavior.ENCODER:
                outputs = {
                    "inputs_embeds": {0: "batch_size", 1: "decoder_sequence_length", 2: "hidden_size"},
                    "decoder_attention_mask": {0: "batch_size", 1: "decoder_sequence_length"},
                    "position_ids": {0: "batch_size", 1: "decoder_sequence_length"},
                }
            elif self._behavior is ConfigBehavior.DECODER and self.decoder_input_processor_export is True:
                outputs = {
                    "inputs_embeds": {0: "batch_size", 2: "hidden_size"},
                    "decoder_attention_mask": {0: "batch_size", 1: "past_decoder_sequence_length + 1"},
                    "position_ids": {0: "batch_size"},
                }
            elif self._behavior is ConfigBehavior.DECODER:
                outputs = {
                    "logits": {0: "batch_size", 1: "decoder_sequence_length", 2: "vocab_size"},
                }
                if self.use_past:
                    self.add_past_key_values(outputs, direction="outputs")

        return outputs

    def overwrite_shape_and_generate_input(
        self, dummy_input_gen: "DummyInputGenerator", input_name: str, framework: str, input_shapes: Dict
    ):
        """Generate one dummy input, forcing input_ids to length 1 when the KV cache is fed."""
        if self.use_past and self.use_past_in_inputs and input_name == "input_ids":
            if self.variant == "default" or (
                self.variant == "optimized" and self.decoder_input_processor_export is True
            ):
                sequence_length = dummy_input_gen.sequence_length
                # Use a sequence length of 1 when the KV cache is already populated.
                dummy_input_gen.sequence_length = 1
                dummy_input = dummy_input_gen.generate(
                    input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
                )
                dummy_input_gen.sequence_length = sequence_length
            else:
                dummy_input = dummy_input_gen.generate(
                    input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
                )
        else:
            dummy_input = dummy_input_gen.generate(
                input_name, framework=framework, int_dtype=self.int_dtype, float_dtype=self.float_dtype
            )

        return dummy_input

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        # Delegated to the wrapped decoder config, which owns the past-kv layout.
        if self.is_decoder_with_past:
            return self._decoder_onnx_config.add_past_key_values(inputs_or_outputs, direction)

    def flatten_past_key_values(self, flattened_output, name, idx, t):
        # Delegated to the wrapped decoder config.
        if self.is_decoder_with_past:
            return self._decoder_onnx_config.flatten_past_key_values(flattened_output, name, idx, t)

    def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> Dict[str, Any]:
        # Delegated to the wrapped decoder config.
        return self._decoder_onnx_config.flatten_output_collection_property(name, field)

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        """Generate dummy inputs, placing the image token where the vision features go."""
        self.PAD_ATTENTION_MASK_TO_PAST = self._decoder_onnx_config.PAD_ATTENTION_MASK_TO_PAST

        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)

        if "pixel_values" in dummy_inputs:
            # Scrub any randomly generated image tokens, then plant exactly one at
            # position 1 for the components that merge vision features.
            input_ids = dummy_inputs["input_ids"]
            mask = input_ids == self._config.image_token_index
            input_ids[mask] = self._config.pad_token_id

            if self._behavior is ConfigBehavior.MONOLITH or self._behavior is ConfigBehavior.ENCODER:
                input_ids[:, 1] = self._config.image_token_index

            dummy_inputs["input_ids"] = input_ids

        if (
            self.variant == "optimized"
            and self._behavior is ConfigBehavior.DECODER
            and self.decoder_input_processor_export is True
        ):
            # The input processor only needs the past length, so a single sliced
            # key tensor stands in for the full cache.
            dummy_inputs["past_key_values"] = dummy_inputs["past_key_values"][0][0][:, :, :, 0]

        return dummy_inputs

    def generate_dummy_inputs_for_validation(
        self, reference_model_inputs: Dict[str, Any], onnx_input_names: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Adapt reference inputs to the exported component's input names."""
        dummy_inputs = super().generate_dummy_inputs_for_validation(reference_model_inputs, onnx_input_names)

        if self.variant == "default" and self._behavior is ConfigBehavior.DECODER:
            # The default-variant decoder export takes no pixel_values.
            dummy_inputs.pop("pixel_values")

        if (
            self.variant == "optimized"
            and self._behavior is ConfigBehavior.DECODER
            and self.decoder_input_processor_export is True
        ):
            # Rename to match the single past-kv input kept in `inputs`.
            dummy_inputs["past_key_values.0.key"] = dummy_inputs.pop("past_key_values")

        return dummy_inputs

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        """Wrap the model in the LLaVA-specific patcher for export."""
        return LlavaModelPatcher(self, model, model_kwargs=model_kwargs)
Loading
Loading