Remove Falcon attention mask patching
baskrahmer committed Oct 31, 2023
1 parent 30a922c commit 1c1a4be
Showing 1 changed file with 1 addition and 15 deletions.
16 changes: 1 addition & 15 deletions optimum/utils/modeling_utils.py
@@ -131,20 +131,6 @@ def _falcon_prepare_attn_mask(
f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
f" {past_key_values_length}."
)
combined_attention_mask = None
device = attention_mask.device
_, seq_length = input_shape

# if seq_length > 1:
# NOTE: we remove here the `if seq_length > 1` to allow to use a single decoder.
combined_attention_mask = _make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)

# [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
)

return combined_attention_mask
return _expand_mask(attention_mask, past_key_values_length=past_key_values_length)

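For readers skimming the diff, here is a minimal sketch, not the optimum or transformers implementation, of the two boolean masks involved. It assumes Falcon's convention that True marks positions that must not be attended to (which is why the removed code combined the masks with |); the helper names make_causal_mask_sketch / expand_mask_sketch and the tensor values are purely illustrative.

import torch

def make_causal_mask_sketch(batch_size, seq_length, past_key_values_length):
    # True strictly above the diagonal: a query may not attend to future keys.
    causal = torch.triu(torch.ones(seq_length, seq_length, dtype=torch.bool), diagonal=1)
    # Past key/value positions are always visible, so prepend a block of False.
    past = torch.zeros(seq_length, past_key_values_length, dtype=torch.bool)
    mask = torch.cat([past, causal], dim=-1)
    return mask[None, None, :, :].expand(batch_size, 1, seq_length, past_key_values_length + seq_length)

def expand_mask_sketch(attention_mask, seq_length):
    # [batch_size, kv_length] padding mask (1 = real token) ->
    # [batch_size, 1, seq_length, kv_length] boolean mask (True = masked out).
    return (~attention_mask.to(torch.bool))[:, None, None, :].expand(-1, 1, seq_length, -1)

attention_mask = torch.tensor([[1, 1, 1, 0]])  # kv_length 4, last position is padding
padding = expand_mask_sketch(attention_mask, seq_length=2)
causal = make_causal_mask_sketch(batch_size=1, seq_length=2, past_key_values_length=2)

old_mask = causal | padding  # what the removed patch returned: causal and padding combined
new_mask = padding           # what _falcon_prepare_attn_mask returns after this commit

In other words, after this commit the patched function no longer merges a causal component in; it returns only the expanded padding mask.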