Skip to content

Commit

Permalink
Support transformers 4.43 (#1971)
Browse files Browse the repository at this point in the history
* fix bt bark test

* setup

* patch clip models for sd

* infer ort model dtype property from inputs dtypes

* patch all clip variants

* device setter

* bigger model for now

* fix device attribution

* onnx opset for owlvit and owlv2

* model dtype

* revert

* use model part dtype instead

* no need for dtype with diffusion pipelines

* revert

* fix clip text model with projection not outputting hidden states

* whisper generation

* fix whisper, support cache_position, and use transformers whisper generation loop

* style

* create cache position for merged decoder and fix test for non whisper speech to text

* typo

* conditioned cache position argument

* update whisper min transformers version

* compare whisper ort generation with transformers

* fix generation length for speech to text model type

* cache position in whisper only with dynamic axis decoder_sequence_length

* use minimal prepare_inputs_for_generation in ORTModelForSpeechSeq2Seq

* remove version restrictions on whisper

* comment

* fix

* simpler

---------

Co-authored-by: Ella Charlaix <[email protected]>
  • Loading branch information
IlyasMoutawwakil and echarlaix authored Aug 5, 2024
1 parent 2a6d857 commit f8f9707
Show file tree
Hide file tree
Showing 12 changed files with 269 additions and 495 deletions.
4 changes: 1 addition & 3 deletions optimum/exporters/onnx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,10 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior is not ConfigBehavior.ENCODER:
if self.use_past_in_inputs:
common_inputs["decoder_input_ids"] = {0: "batch_size"}
self.add_past_key_values(common_inputs, direction="inputs")
else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")

if self._behavior is ConfigBehavior.DECODER:
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}

Expand Down
42 changes: 40 additions & 2 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
NormalizedTextConfig,
NormalizedTextConfigWithGQA,
NormalizedVisionConfig,
check_if_transformers_greater,
is_diffusers_available,
logging,
)
Expand All @@ -71,6 +72,7 @@
)
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import (
CLIPModelPatcher,
FalconModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
Expand Down Expand Up @@ -913,10 +915,16 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

return common_outputs

def patch_model_for_export(
    self,
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
    """
    Build the patcher used around `model` during export.

    Args:
        model: The model about to be exported.
        model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

    Returns:
        A `CLIPModelPatcher` built from this config, `model` and `model_kwargs`.
    """
    patcher = CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
    return patcher


class CLIPOnnxConfig(TextAndVisionOnnxConfig):
NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
DEFAULT_ONNX_OPSET = 14

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -935,6 +943,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
"image_embeds": {0: "image_batch_size"},
}

def patch_model_for_export(
    self,
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
    """
    Build the patcher used around `model` during export.

    Args:
        model: The model about to be exported.
        model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

    Returns:
        A `CLIPModelPatcher` built from this config, `model` and `model_kwargs`.
    """
    patcher = CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
    return patcher


class SentenceTransformersCLIPOnnxConfig(CLIPOnnxConfig):
@property
Expand Down Expand Up @@ -980,6 +995,13 @@ def outputs(self) -> Dict[str, Dict[int, str]]:

return common_outputs

def patch_model_for_export(
    self,
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
    """
    Build the patcher used around `model` during export.

    Args:
        model: The model about to be exported.
        model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

    Returns:
        A `CLIPModelPatcher` built from this config, `model` and `model_kwargs`.
    """
    patcher = CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
    return patcher


class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
@property
Expand All @@ -997,12 +1019,20 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
    """
    Generate the parent class dummy inputs, with `input_ids` cast to int32 for PyTorch.

    Args:
        framework: Framework the dummy inputs are generated for. The cast is only applied for `"pt"`.
        **kwargs: Forwarded unchanged to the parent implementation.

    Returns:
        The dummy inputs returned by the parent class, where `input_ids` has been converted to
        `torch.int32` when `framework` is `"pt"`.
    """
    inputs = super().generate_dummy_inputs(framework=framework, **kwargs)

    if framework != "pt":
        return inputs

    # TODO: fix should be by casting inputs during inference and not export
    import torch

    inputs["input_ids"] = inputs["input_ids"].to(dtype=torch.int32)
    return inputs

def patch_model_for_export(
    self,
    model: Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"],
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> "ModelPatcher":
    """
    Build the patcher used around `model` during export.

    Args:
        model: The model about to be exported.
        model_kwargs: Optional keyword arguments, forwarded unchanged to the patcher.

    Returns:
        A `CLIPModelPatcher` built from this config, `model` and `model_kwargs`.
    """
    patcher = CLIPModelPatcher(self, model, model_kwargs=model_kwargs)
    return patcher


class UNetOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
Expand Down Expand Up @@ -1135,6 +1165,9 @@ class OwlViTOnnxConfig(CLIPOnnxConfig):
ATOL_FOR_VALIDATION = 1e-4
MIN_TORCH_VERSION = version.parse("2.1")

# needs einsum operator support, available since opset 12
DEFAULT_ONNX_OPSET = 12

def __init__(
self,
config: "PretrainedConfig",
Expand Down Expand Up @@ -1438,7 +1471,12 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior is not ConfigBehavior.DECODER:
common_inputs["input_features"] = {0: "batch_size"} # Remove unnecessary dynamic axis.

if self._behavior is ConfigBehavior.DECODER and self.use_past_in_inputs is False:
if self._behavior is not ConfigBehavior.ENCODER and self.use_past_in_inputs:
if check_if_transformers_greater("4.43.0"):
# since https://github.com/huggingface/transformers/pull/31166
common_inputs["cache_position"] = {0: "decoder_sequence_length"}

if self._behavior is ConfigBehavior.DECODER and not self.use_past_in_inputs:
common_inputs["encoder_outputs"][1] = f"{common_inputs['encoder_outputs'][1]} / 2"
return common_inputs

Expand Down
17 changes: 17 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,3 +1138,20 @@ def __init__(
self._update_causal_mask_original = self._model.model._update_causal_mask
else:
self._update_causal_mask_original = self._model._update_causal_mask


class CLIPModelPatcher(ModelPatcher):
    """
    Patcher that, with transformers >= 4.43, makes `CLIPSdpaAttention` use the forward of
    `CLIPAttention` for as long as the context is active.

    The replacement is done on the `CLIPSdpaAttention` class itself and is reverted on exit.
    """

    def __enter__(self):
        super().__enter__()

        if _transformers_version < version.parse("4.43"):
            return

        from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention

        # Remember the SDPA forward so that `__exit__` can put it back.
        self.original_sdpa_forward = CLIPSdpaAttention.forward
        CLIPSdpaAttention.forward = CLIPAttention.forward

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)

        if _transformers_version < version.parse("4.43"):
            return

        from transformers.models.clip.modeling_clip import CLIPSdpaAttention

        CLIPSdpaAttention.forward = self.original_sdpa_forward
10 changes: 6 additions & 4 deletions optimum/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _get_submodels_for_export_diffusion(
pipeline, (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline)
)
is_stable_diffusion_xl = isinstance(
pipeline, (StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline)
pipeline, (StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline)
)
is_latent_consistency_model = isinstance(
pipeline, (LatentConsistencyModelPipeline, LatentConsistencyModelImg2ImgPipeline)
Expand All @@ -117,10 +117,11 @@ def _get_submodels_for_export_diffusion(
models_for_export = {}

# Text encoder
if pipeline.text_encoder is not None:
text_encoder = getattr(pipeline, "text_encoder", None)
if text_encoder is not None:
if is_stable_diffusion_xl:
pipeline.text_encoder.config.output_hidden_states = True
models_for_export["text_encoder"] = pipeline.text_encoder
text_encoder.config.output_hidden_states = True
models_for_export["text_encoder"] = text_encoder

# U-NET
# ONNX export of torch.nn.functional.scaled_dot_product_attention not supported for < v2.1.0
Expand Down Expand Up @@ -151,6 +152,7 @@ def _get_submodels_for_export_diffusion(
text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
text_encoder_2.config.output_hidden_states = True
text_encoder_2.text_model.config.output_hidden_states = True
models_for_export["text_encoder_2"] = text_encoder_2

return models_for_export
Expand Down
50 changes: 38 additions & 12 deletions optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from ..utils import NormalizedConfigManager
from ..utils.logging import warn_once
from .io_binding import TypeHelper
from .modeling_ort import ORTModel
from .utils import get_ordered_input_names, logging

Expand Down Expand Up @@ -62,6 +63,20 @@ def __init__(
def device(self):
return self.parent_model.device

@property
def dtype(self):
    """
    First floating point torch dtype found among the input dtypes, then the output dtypes.

    Each entry is converted with `TypeHelper.ort_type_to_torch_type`. Returns `None` when
    neither the inputs nor the outputs contain a floating point type.
    """
    # Inputs take precedence over outputs; the attribute is only read when its turn comes.
    for io_dtypes_name in ("input_dtypes", "output_dtypes"):
        for ort_type in getattr(self, io_dtypes_name).values():
            candidate = TypeHelper.ort_type_to_torch_type(ort_type)
            if candidate.is_floating_point:
                return candidate

    return None

@abstractmethod
def forward(self, *args, **kwargs):
pass
Expand Down Expand Up @@ -220,6 +235,7 @@ def forward(
encoder_attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
use_cache_branch: None = None,
) -> Seq2SeqLMOutput:
# Adding use_cache_branch in the signature here is just a hack for IO Binding
Expand All @@ -236,8 +252,8 @@ def forward(
# no-ops if merged decoder is not used
use_merged_no_cache = past_key_values is None and self.parent_model.use_merged
use_merged_cache = past_key_values is not None and self.parent_model.use_merged
use_cache_branch_tensor, past_key_values = self.prepare_inputs_for_merged(
input_ids, past_key_values, use_torch=use_torch
use_cache_branch_tensor, past_key_values, cache_position = self.prepare_inputs_for_merged(
input_ids, past_key_values, cache_position, use_torch=use_torch
)

if self.parent_model.use_io_binding:
Expand Down Expand Up @@ -274,6 +290,9 @@ def forward(
if use_cache_branch_tensor is not None:
model_inputs.append(use_cache_branch_tensor)

if "cache_position" in self.input_names:
model_inputs.append(cache_position)

io_binding, output_shapes, output_buffers = self.parent_model._prepare_io_binding(
self.session,
*model_inputs,
Expand Down Expand Up @@ -346,6 +365,7 @@ def forward(
"decoder_attention_mask": decoder_attention_mask,
"encoder_attention_mask": encoder_attention_mask,
"use_cache_branch": use_cache_branch_tensor,
"cache_position": cache_position,
"labels": labels,
}
if past_key_values is not None:
Expand Down Expand Up @@ -405,20 +425,20 @@ def forward(

def prepare_inputs_for_merged(
self,
input_ids: Union[None, torch.LongTensor, np.ndarray],
past_key_values: Union[None, Tuple[torch.FloatTensor], Tuple[np.ndarray]],
input_ids: Optional[Union[torch.LongTensor, np.ndarray]],
past_key_values: Optional[Tuple[Union[torch.FloatTensor, np.ndarray]]],
cache_position: Optional[Union[torch.Tensor, np.ndarray]],
use_torch: bool,
):
constructor = torch if use_torch is True else np

if self.parent_model.use_merged:
constructor = torch if use_torch is True else np
# Uses without/with branch of a merged decoder depending on whether real past key values are passed
use_cache_branch = constructor.full((1,), past_key_values is not None)
use_cache_branch_tensor = constructor.full((1,), past_key_values is not None)
if use_torch and use_cache_branch_tensor is not None:
use_cache_branch_tensor = use_cache_branch_tensor.to(self.device)
else:
# Uses separate decoders
use_cache_branch = None

if use_torch and use_cache_branch is not None:
use_cache_branch = use_cache_branch.to(self.device)
use_cache_branch_tensor = None

# Generate dummy past for the first forward if uses a merged decoder
if self.parent_model.use_merged and past_key_values is None:
Expand All @@ -434,7 +454,13 @@ def prepare_inputs_for_merged(

past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))

return use_cache_branch, past_key_values
# Generate a dummy cache position for the first forward pass when a merged decoder is used
if self.parent_model.use_merged and cache_position is None:
cache_position = constructor.zeros((1,), dtype=constructor.int64)
if use_torch is True:
cache_position = cache_position.to(self.device)

return use_cache_branch_tensor, past_key_values, cache_position


class ORTDecoder(ORTDecoderForSeq2Seq):
Expand Down
8 changes: 7 additions & 1 deletion optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,14 @@ def to(self, device: Union[torch.device, str, int]):
Returns:
`ORTModel`: the model placed on the requested device.
"""

device, provider_options = parse_device(device)
provider = get_provider_for_device(device)
validate_provider_availability(provider) # raise error if the provider is not available
self.device = device

if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
return self

self.vae_decoder.session.set_providers([provider], provider_options=[provider_options])
self.text_encoder.session.set_providers([provider], provider_options=[provider_options])
self.unet.session.set_providers([provider], provider_options=[provider_options])
Expand All @@ -464,6 +468,8 @@ def to(self, device: Union[torch.device, str, int]):
self.vae_encoder.session.set_providers([provider], provider_options=[provider_options])

self.providers = self.vae_decoder.session.get_providers()
self._device = device

return self

@classmethod
Expand Down
28 changes: 23 additions & 5 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,24 @@ def __init__(

self._ordered_input_names = get_ordered_input_names(self.input_names.keys(), func=self.forward)

# TODO: why do we make device a property since we are only access the value, and do not do any check when setting the value?
@property
def dtype(self) -> torch.dtype:
    """
    `torch.dtype`: The dtype of the model, i.e. the first floating point dtype found among its
    input dtypes, then among its output dtypes. `None` is returned when there is no such dtype.
    """

    # Inputs take precedence over outputs; the attribute is only read when its turn comes.
    for io_dtypes_name in ("input_dtypes", "output_dtypes"):
        for ort_type in getattr(self, io_dtypes_name).values():
            candidate = TypeHelper.ort_type_to_torch_type(ort_type)
            if candidate.is_floating_point:
                return candidate

    return None

@property
def device(self) -> torch.device:
"""
Expand All @@ -286,8 +303,8 @@ def device(self) -> torch.device:
return self._device

@device.setter
def device(self, value: torch.device):
    """
    Reject direct assignment to `device`.

    A property setter is always called with the assigned value as a positional argument, so the
    signature has to accept it: a `**kwargs`-only signature makes `model.device = ...` fail with
    a `TypeError` before the error below can be raised.

    Args:
        value: The value being assigned. It is ignored.

    Raises:
        AttributeError: Always, the device can only be changed through the `to` method.
    """
    raise AttributeError("The device attribute is read-only, please use the `to` method to change the device.")

@property
def use_io_binding(self):
Expand All @@ -309,13 +326,13 @@ def to(self, device: Union[torch.device, str, int]):
Returns:
`ORTModel`: the model placed on the requested device.
"""

device, provider_options = parse_device(device)

if device.type == "cuda" and self.providers[0] == "TensorrtExecutionProvider":
return self

self.device = device
provider = get_provider_for_device(self.device)
provider = get_provider_for_device(device)
validate_provider_availability(provider) # raise error if the provider is not available

# IOBinding is only supported for CPU and CUDA Execution Providers.
Expand All @@ -331,6 +348,7 @@ def to(self, device: Union[torch.device, str, int]):

self.model.set_providers([provider], provider_options=[provider_options])
self.providers = self.model.get_providers()
self._device = device

return self

Expand Down
Loading

0 comments on commit f8f9707

Please sign in to comment.