From c27491802978a6a776a361ec53874d0d61aa0749 Mon Sep 17 00:00:00 2001 From: Ilia Taraban Date: Sun, 8 Sep 2024 22:30:43 +0200 Subject: [PATCH] Fix attn_bias calculation for alibi --- vllm/attention/backends/habana_attn.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 20b0f2bc7630b..56b71a431aca7 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -108,17 +108,10 @@ def __init__( self.v_cache = VLLMKVCache() self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window - self.position_bias = None self.alibi_slopes = alibi_slopes if alibi_slopes is not None: - # FIXME(kzawora): Need a general method to set max_seq_len on - # per-model basis. alibi_slopes_tensor = torch.tensor(alibi_slopes, dtype=torch.bfloat16) - self.position_bias = _make_alibi_bias(alibi_slopes_tensor, - num_kv_heads, - alibi_slopes_tensor.dtype, - max_seq_len) self.alibi_slopes = alibi_slopes_tensor assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -190,11 +183,13 @@ def forward( assert attn_metadata.attn_bias is not None, \ 'attn_bias must be set before calling model.forward!' attn_bias = attn_metadata.attn_bias - if self.alibi_slopes is not None and \ - self.position_bias is not None: - attn_bias.add_(self.position_bias[:, :, - -attn_bias.size(2):, - -attn_bias.size(3):]) + if self.alibi_slopes is not None: + position_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, + attn_bias.dtype, + attn_bias.shape[-1]) + attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) + attn_bias.add_(position_bias) else: attn_bias = None