Skip to content

Commit

Permalink
remove unused method
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Nov 2, 2023
1 parent 3881280 commit 05fc5f7
Showing 1 changed file with 0 additions and 25 deletions.
25 changes: 0 additions & 25 deletions optimum/utils/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

import functools

import torch


MODEL_TO_PATCH_FOR_PAST = {
"bart",
Expand Down Expand Up @@ -54,26 +52,3 @@ def recurse_setattr(module, name, value):
else:
name, rest = name.split(".", 1)
recurse_setattr(getattr(module, name), rest, value)


# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size,
device: torch.device,
past_key_values_length: int,
dtype: torch.dtype = torch.bool,
) -> torch.BoolTensor:
"""
Make causal mask used for bi-directional self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device)
seq_ids = torch.arange(target_length, device=device)

mask[:, past_key_values_length:] = (
(seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min
if torch.is_floating_point(mask)
else seq_ids[:, None] < seq_ids[None, :]
)

return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)

0 comments on commit 05fc5f7

Please sign in to comment.