Skip to content

Commit

Permalink
fix make_causal patching
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Nov 2, 2023
1 parent 967b8c9 commit 3881280
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ def falcon_model_forward_without_kv_reformatting(


def _make_causal_mask_patched(
self,
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
Expand Down Expand Up @@ -402,6 +401,9 @@ def _make_causal_mask_patched(
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


_make_causal_mask_patched = staticmethod(_make_causal_mask_patched)


class DecoderModelPatcher(ModelPatcher):
def __enter__(self):
# TODO: Remove this if once transformers if much above 4.35
Expand All @@ -410,11 +412,8 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
# TODO: Remove this if once transformers if much above 4.35
# TODO: We should unpatch it - however `self._make_causal_mask` may still be called later which raises issues with this simple patch strategy.
# We need to find a proper solution.
# if AttentionMaskConverter is not None:
# AttentionMaskConverter._make_causal_mask = self.original_make_causal
pass
if AttentionMaskConverter is not None:
AttentionMaskConverter._make_causal_mask = self.original_make_causal

def __init__(
self,
Expand Down

0 comments on commit 3881280

Please sign in to comment.