Skip to content

Commit

Permalink
Support transformers 4.42 (#1929)
Browse files Browse the repository at this point in the history
* support transformers 4.42

* fix mistral

* update opsets

* fix _supports_cache_class

* typo

* nit

* remove onnxruntime extra in ci

* fix
  • Loading branch information
fxmarty authored Jul 1, 2024
1 parent a5500c7 commit 86adc3e
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_onnx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install .[tests,onnxruntime,exporters-tf]
pip install .[tests,exporters]
- name: Test with unittest
working-directory: tests
run: |
Expand Down
10 changes: 8 additions & 2 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import (
FalconModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
SAMModelPatcher,
SentenceTransformersCLIPPatcher,
Expand Down Expand Up @@ -237,7 +238,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:


class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")


Expand All @@ -259,7 +260,7 @@ class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):


class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


Expand Down Expand Up @@ -312,6 +313,11 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return MistralModelPatcher(self, model, model_kwargs=model_kwargs)


class MPTOnnxConfig(TextDecoderOnnxConfig):
# MPT does not require position_ids input.
Expand Down
197 changes: 197 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
_prepare_4d_causal_attention_mask_for_sdpa = None
AttentionMaskConverter = None

if _transformers_version >= version.parse("4.42"):
from transformers.cache_utils import SlidingWindowCache, StaticCache

if TYPE_CHECKING:
from transformers import PreTrainedModel, TFPreTrainedModel

Expand Down Expand Up @@ -746,6 +749,20 @@ def patched_forward(


class SentenceTransformersTransformerPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral":
self._model[0].auto_model._update_causal_mask = types.MethodType(
_update_causal_mask_patched, self._model[0].auto_model
)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral":
self._model[0].auto_model._update_causal_mask = types.MethodType(
self._update_causal_mask_original, self._model[0].auto_model
)

def __init__(
self,
config: "OnnxConfig",
Expand All @@ -754,6 +771,8 @@ def __init__(
):
super().__init__(config, model, model_kwargs)

self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask

def patched_forward(input_ids, attention_mask):
result = self.orig_forward({"input_ids": input_ids, "attention_mask": attention_mask})

Expand Down Expand Up @@ -931,3 +950,181 @@ def patched_forward(
return {"audio_values": audio_values}

self.patched_forward = patched_forward


def _update_causal_mask_patched(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values,
use_cache: bool,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self._attn_implementation == "flash_attention_2":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.

# cache_position must be valid here no matter which cache we use
past_seen_tokens = cache_position[0] if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
target_length = max(sequence_length, self.config.sliding_window)
# StaticCache
elif using_static_cache:
target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if self.config.sliding_window is not None:
if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
# ---------------- NOTE: This part is patched -----------------------------
exclude_mask.bitwise_or_(
torch.arange(target_length, device=device)
<= (cache_position.reshape(-1, 1) - self.config.sliding_window)
)
# ---------------- NOTE: patch end ----------------------------------------

causal_mask *= exclude_mask
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)

# if (
# self.config._attn_implementation == "sdpa"
# and attention_mask is not None
# and attention_mask.device.type == "cuda"
# and not output_attentions
# ):
# # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# # Details: https://github.com/pytorch/pytorch/issues/110213
# causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask


class MistralModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if AttentionMaskConverter is not None:
# TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35
AttentionMaskConverter._make_causal_mask = _make_causal_mask_patched_staticmethod

if _transformers_version >= version.parse("4.36"):
AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod

if _transformers_version >= version.parse("4.36"):
patch_everywhere(
"_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched
)

if _transformers_version >= version.parse("4.42"):
if hasattr(self._model, "model"):
self._model.model._update_causal_mask = types.MethodType(
_update_causal_mask_patched, self._model.model
)
else:
self._model._update_causal_mask = types.MethodType(_update_causal_mask_patched, self._model)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if AttentionMaskConverter is not None:
# TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35
AttentionMaskConverter._make_causal_mask = staticmethod(self.original_make_causal)

if _transformers_version >= version.parse("4.36"):
AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended)

if _transformers_version >= version.parse("4.36"):
patch_everywhere(
"_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa
)

if _transformers_version >= version.parse("4.42"):
if hasattr(self._model, "model"):
self._model.model._update_causal_mask = types.MethodType(
self._update_causal_mask_original, self._model.model
)
else:
self._model._update_causal_mask = types.MethodType(self._update_causal_mask_original, self._model)

def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

if _transformers_version >= version.parse("4.36"):
self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa
self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended

# TODO: Remove this if once transformers if much above 4.35
if AttentionMaskConverter is not None:
self.original_make_causal = AttentionMaskConverter._make_causal_mask

if _transformers_version >= version.parse("4.42"):
if hasattr(self._model, "model"):
self._update_causal_mask_original = self._model.model._update_causal_mask
else:
self._update_causal_mask_original = self._model._update_causal_mask
1 change: 1 addition & 0 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class ORTModelForCausalLM(ORTModel, GenerationMixin):

auto_model_class = AutoModelForCausalLM
main_input_name = "input_ids"
_supports_cache_class = False

def __init__(
self,
Expand Down
5 changes: 3 additions & 2 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,9 +1092,10 @@ def forward(
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)

if "last_hidden_state" in self.output_names:
last_hidden_state = model_outputs[self.output_names["last_hidden_state"]]
last_hidden_state = model_outputs["last_hidden_state"]
else:
last_hidden_state = model_outputs[0]
# TODO: This allows to support sentence-transformers models (sentence embedding), but is not validated.
last_hidden_state = next(iter(model_outputs.values()))

# converts output to namedtuple for pipelines post-processing
return BaseModelOutput(last_hidden_state=last_hidden_state)
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ class ORTModelForConditionalGeneration(ORTModel, ABC):

# Used in from_transformers to export model to onnxORTEncoder
base_model_prefix = "onnx_model"
_supports_cache_class = False

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
REQUIRED_PKGS = [
"coloredlogs",
"sympy",
"transformers[sentencepiece]>=4.26.0,<4.42.0",
"transformers[sentencepiece]>=4.26.0,<4.43.0",
"torch>=1.11",
"packaging",
"numpy<2.0", # transformers requires numpy<2.0 https://github.com/huggingface/transformers/pull/31569
Expand Down

0 comments on commit 86adc3e

Please sign in to comment.