Skip to content

Commit

Permalink
address review
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Oct 17, 2023
1 parent 14e2ad8 commit 54aa31e
Showing 1 changed file with 40 additions and 7 deletions.
47 changes: 40 additions & 7 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dataclasses
import functools
import inspect
import types
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import transformers
Expand Down Expand Up @@ -378,6 +379,42 @@ def falcon_model_forward_without_kv_reformatting(


class FalconModelPatcher(ModelPatcher):
def __enter__(self):
    """Install the Falcon-specific patches for the duration of the context.

    Calls ``self.patch_ops()``, swaps the module-level ``_make_causal_mask``
    of ``transformers.models.falcon.modeling_falcon`` for the patched
    version, overrides ``_prepare_attn_mask`` on the Falcon transformer, and
    finally installs ``self.patched_forward`` on the model under
    ``self.orig_forward_name``. For the "text-generation" task the
    transformer's ``forward`` is additionally replaced on the instance.

    Returns:
        None.
    """
    self.patch_ops()

    falcon_modeling = transformers.models.falcon.modeling_falcon
    falcon_modeling._make_causal_mask = _make_causal_mask_falcon_patched

    model = self._model

    if self.real_config.task == "text-generation":
        # Bind the replacement to this transformer instance only, leaving
        # the class attribute untouched.
        transformer = model.transformer
        transformer.forward = types.MethodType(falcon_model_forward_without_kv_reformatting, transformer)

    # In order to use a single decoder, `_prepare_attn_mask` must behave
    # independently of the sequence length. It lives on the model itself
    # when the model is a bare FalconModel, otherwise on its `.transformer`.
    mask_owner = model if isinstance(model, FalconModel) else model.transformer
    mask_owner._prepare_attn_mask = _falcon_prepare_attn_mask

    setattr(model, self.orig_forward_name, self.patched_forward)

def __exit__(self, exc_type, exc_value, traceback):
    """Undo the patches installed by ``__enter__``.

    Calls ``self.restore_ops()``, puts ``self.orig_forward`` back on the
    model under ``self.orig_forward_name``, and restores the transformer's
    ``forward`` (for the "text-generation" task), the module-level
    ``_make_causal_mask`` and ``_prepare_attn_mask`` from the values saved
    on this patcher.

    Args:
        exc_type: Exception type raised inside the ``with`` body, if any (unused).
        exc_value: Exception instance raised inside the ``with`` body, if any (unused).
        traceback: Traceback of that exception, if any (unused).

    Returns:
        None, so an exception raised inside the ``with`` body propagates.
    """
    self.restore_ops()

    setattr(self._model, self.orig_forward_name, self.orig_forward)

    if self.real_config.task == "text-generation":
        # `original_model_transformer_forward` was read off the transformer
        # instance, so it is already a bound method. Wrapping it again in
        # `types.MethodType` would pass the transformer a second time as the
        # first positional argument and break every later `forward` call.
        self._model.transformer.forward = self.original_model_transformer_forward

    transformers.models.falcon.modeling_falcon._make_causal_mask = self.original_make_causal

    # Restore the saved `_prepare_attn_mask` on the object it was patched on:
    # the model itself for a bare FalconModel, otherwise its `.transformer`.
    if isinstance(self._model, FalconModel):
        self._model._prepare_attn_mask = self.original_falcon_prepare_attn_mask
    else:
        self._model.transformer._prepare_attn_mask = self.original_falcon_prepare_attn_mask

def __init__(
self,
config: "OnnxConfig",
Expand All @@ -386,19 +423,15 @@ def __init__(
):
super().__init__(config, model, model_kwargs)

# This is kind of ugly and bug prone if other FalconModel are instantiated.
if config.task == "text-generation":
model.transformer.__class__.forward = falcon_model_forward_without_kv_reformatting
self.original_model_transformer_forward = model.transformer.forward

self.original_make_causal = transformers.models.falcon.modeling_falcon._make_causal_mask

transformers.models.falcon.modeling_falcon._make_causal_mask = _make_causal_mask_falcon_patched

# In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length.
if isinstance(model, FalconModel):
model._prepare_attn_mask = _falcon_prepare_attn_mask
self.original_falcon_prepare_attn_mask = model._prepare_attn_mask
else:
model.transformer._prepare_attn_mask = _falcon_prepare_attn_mask
self.original_falcon_prepare_attn_mask = model.transformer._prepare_attn_mask

self._model = model

Expand Down

0 comments on commit 54aa31e

Please sign in to comment.