Resolve ALiBi bias regression caused by porting flat PA #503

Open · wants to merge 1 commit into base: habana_main
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0766759
2 changes: 2 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -233,6 +233,8 @@ def __init__(
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
tp_rank: Optional[int] = None,
prev_attn: Optional[torch.nn.Module] = None,

Do we need to modify this class's constructor here? This will be hard to upstream. The TP rank can be obtained using from vllm.distributed import get_tensor_model_parallel_rank, and if I'm understanding this correctly, prev_attn is only used here to reuse the ALiBi bias, which can be generated in each layer separately. Alternatively, it could probably be cached.
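A minimal sketch of that alternative, assuming the suggestion above: the impl queries its own TP rank via vllm.distributed, and the prompt ALiBi bias comes from a cached helper keyed on slopes, length, and dtype, so neither tp_rank nor prev_attn needs to be threaded through the constructor. The helper names and cache size are illustrative, not part of this PR.

```python
import functools
from typing import Tuple

import torch

from vllm.distributed import get_tensor_model_parallel_rank


def _current_tp_rank() -> int:
    # No constructor argument needed; each impl can query its own rank.
    return get_tensor_model_parallel_rank()


@functools.lru_cache(maxsize=8)
def _cached_prompt_alibi_bias(alibi_slopes: Tuple[float, ...],
                              seq_len: int,
                              dtype: torch.dtype) -> torch.Tensor:
    """Build the [1, num_heads, seq_len, seq_len] prompt bias once and
    reuse it across layers sharing the same slopes, length, and dtype."""
    slopes = torch.tensor(alibi_slopes, dtype=dtype)
    positions = torch.arange(seq_len, dtype=dtype)
    # bias[i, j] = j - i, the same relative-distance layout used by the PR.
    rel = positions[None, :] - positions[:, None]
    return rel[None, None, :, :] * slopes[:, None, None]
```

A layer's __init__ would then call _cached_prompt_alibi_bias(tuple(alibi_slopes), max_seq_len, torch.bfloat16); identical layers hit the cache instead of reading prev_attn.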

) -> None:
raise NotImplementedError

1 change: 1 addition & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -300,6 +300,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
1 change: 1 addition & 0 deletions vllm/attention/backends/flash_attn.py
@@ -600,6 +600,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
if blocksparse_params is not None:
raise ValueError(
1 change: 1 addition & 0 deletions vllm/attention/backends/flashinfer.py
@@ -748,6 +748,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
187 changes: 164 additions & 23 deletions vllm/attention/backends/hpu_attn.py
@@ -125,6 +125,9 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
tp_rank: Optional[int] = None,
prev_attn: Optional[torch.nn.Module] = None,
**kwargs,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
@@ -142,11 +145,42 @@ def __init__(
else ModuleFusedSDPA(HPUFusedSDPA)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.alibi_slopes = alibi_slopes
self.alibi_slopes = None
self.prompt_position_bias = None
# Set upper bound on sequence length
self.max_seq_len = int(

Please change it to max_seq_len_upper_bound or something similar. To me this looks misleading, considering that this variable is reused.

os.getenv(
'VLLM_PROMPT_ALIBI_MAX_SEQ_LEN',
max_seq_len,
))
# Set lower bound on sequence length
self.max_seq_len = max([
self.max_seq_len,
int(os.getenv('VLLM_PROMPT_SEQ_BUCKET_MAX', '0')),
])
self.tp_rank = tp_rank
self.prev_attn = None if prev_attn is None else prev_attn.impl
if alibi_slopes is not None:
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16)
self.alibi_slopes = alibi_slopes_tensor
if (self.prev_attn is not None
and self.prev_attn.tp_rank == self.tp_rank):
self.alibi_slopes = self.prev_attn.alibi_slopes
self.prompt_position_bias = self.prev_attn.prompt_position_bias
else:
slope_tensor_dtype = {
True: torch.float32,
False: torch.bfloat16,
}[os.getenv('VLLM_ALIBI_USE_FLOAT32_BIASES', '1').lower()
in ['1', 'true']]
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=slope_tensor_dtype)
self.alibi_slopes = alibi_slopes_tensor
# Creating the prompt_position_bias once and reusing it
# if seq_len permits.
self.prompt_position_bias = _make_prompt_alibi_bias(
alibi_slopes=self.alibi_slopes,
seq_len=self.max_seq_len,
dtype=self.alibi_slopes.dtype,
)
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

@@ -157,6 +191,12 @@ def __init__(
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
'true').lower() == 'true'
if not self.use_contiguous_pa:
assert alibi_slopes is None, \
'Non-contiguous PA not supported with alibi slopes!'

suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
@@ -230,27 +270,58 @@ def forward(
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)

if attn_metadata is None or attn_metadata.block_list is None:
if not self.prefill_use_fusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
# If we have alibi_slopes, incorporate them with
# position_bias and position_bias_offset.
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads,
attn_bias.dtype, attn_bias.shape[-1])
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
seq_lens_tensor = attn_metadata.seq_lens_tensor
position_bias = None
position_bias_offset = None
if (self.prompt_position_bias is not None
and self.alibi_slopes is not None):
if self.max_seq_len >= max(attn_bias.size(-2),
attn_bias.size(-1)):
# Using pre-computed prompt_position_bias subset.
position_bias = self.prompt_position_bias[:, :,
-attn_bias.size(-2):,
-attn_bias.size(-1):]
else:
# For longer sequences than precomputed,
# recreate the bias. This is memory inefficient.
position_bias = _make_prompt_alibi_bias(
alibi_slopes=self.alibi_slopes,
seq_len=max(attn_bias.size(-2),
attn_bias.size(-1)),
dtype=self.alibi_slopes.dtype,
)
# If seq_lens_tensor is provided, we create a
# position_bias_offset. This offset helps handle
# sequences of varying lengths in a batch.
if seq_lens_tensor is not None:
position_bias_offset = seq_lens_tensor.unsqueeze(
1).tile(1, self.num_heads).to(
dtype=self.alibi_slopes.dtype)
position_bias_offset.mul_(
self.alibi_slopes[None, :])
position_bias_offset = position_bias_offset \
- position_bias[:, :, -1, 0]
else:
attn_bias = None
position_bias = None
position_bias_offset = None

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
position_bias=position_bias,
position_bias_offset=position_bias_offset,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
@@ -278,6 +349,20 @@ def forward(
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
self.position_bias = None
alibi_blocks = attn_metadata.alibi_blocks
if self.alibi_slopes is not None and alibi_blocks is not None:
if (self.prev_attn is not None
and self.prev_attn.tp_rank == self.tp_rank):
self.position_bias = self.prev_attn.position_bias
else:
# For decoding, compute position bias using alibi_blocks.
self.position_bias = _make_decode_alibi_bias(
alibi_blocks=alibi_blocks,
alibi_slopes=self.alibi_slopes,
dtype=self.alibi_slopes.dtype,
)

output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
@@ -288,14 +373,18 @@ def forward(
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
position_bias=self.position_bias,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
values_fetch_func=self.v_cache.fetch_from_cache,
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
output = output.view(batch_size, seq_len, hidden_size)
return output

def forward_encoder_decoder(
self,
@@ -409,12 +498,25 @@ def forward_encoder_decoder(
return output.view(batch_size, -1, hidden_size)


def _make_alibi_bias(
def _make_prompt_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_len: int,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Create the ALiBi position bias tensor for prompt stage.
This tensor is reused or tiled as needed for each forward pass.
Does not scale with batch size or number of blocks.

Args:
alibi_slopes: shape = [num_heads]
seq_len: int
dtype: torch.dtype

Returns:
A per-head bias tensor of shape [1, num_heads, seq_len, seq_len].
This bias encodes positional information via ALiBi slopes.
"""
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
@@ -427,15 +529,54 @@ def _make_alibi_bias(

padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
per_head_bias = torch.empty(
1,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return bias
)[:, :, :, :seq_len]
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias
# to all 32 heads in Eager mode.
per_head_bias[:, :] = bias
per_head_bias.mul_(alibi_slopes[:, None, None])

return per_head_bias


def _make_decode_alibi_bias(
alibi_blocks: torch.Tensor,
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Create the ALiBi position bias tensor for decode stage.
Uses stored alibi_blocks and slopes for final scaling.
Scales with number of blocks, not with batch size.

Args:
alibi_blocks: shape = [num_blocks, block_size]
alibi_slopes: shape = [num_heads]
dtype: torch.dtype

Returns:
A per-head bias tensor of shape [num_blocks, num_heads, block_size].
Each row encodes position-dependent ALiBi slopes for decoding steps.
"""
num_heads = alibi_slopes.shape[0]
per_head_bias = torch.empty(
alibi_blocks.size(0),
num_heads,
alibi_blocks.size(-1),
device=alibi_slopes.device,
dtype=dtype,
)
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias
# to all 32 heads in Eager mode.
per_head_bias[:, :] = alibi_blocks.unsqueeze(-2)
per_head_bias.mul_(alibi_slopes[None, :, None])

return per_head_bias
1 change: 1 addition & 0 deletions vllm/attention/backends/ipex_attn.py
@@ -115,6 +115,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
if blocksparse_params is not None:
raise ValueError(
1 change: 1 addition & 0 deletions vllm/attention/backends/pallas.py
@@ -100,6 +100,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
1 change: 1 addition & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -338,6 +338,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
if blocksparse_params is not None:
raise ValueError(
1 change: 1 addition & 0 deletions vllm/attention/backends/torch_sdpa.py
@@ -390,6 +390,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
if blocksparse_params is not None:
raise ValueError(
1 change: 1 addition & 0 deletions vllm/attention/backends/xformers.py
@@ -381,6 +381,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
if blocksparse_params is not None:
raise ValueError(
7 changes: 5 additions & 2 deletions vllm/attention/layer.py
@@ -38,9 +38,11 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
logits_soft_cap: Optional[int] = 4096,

That default value is already set for our implementation, so I'd rather not change it for others
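One way to honor that, sketched under the assumption that the HPU backend already applies its own default: leave the shared layer signature at logits_soft_cap: Optional[float] = None and apply the fallback only inside the HPU impl. The class name and constant below are illustrative, not part of this PR.

```python
from typing import Optional


class HPUAttentionImplSketch:
    """Illustrative only: the shared Attention layer keeps
    logits_soft_cap=None and the HPU backend supplies its own default."""

    HPU_DEFAULT_SOFT_CAP = 4096.0  # assumed backend default

    def __init__(self, logits_soft_cap: Optional[float] = None) -> None:
        self.logits_soft_cap = (self.HPU_DEFAULT_SOFT_CAP
                                if logits_soft_cap is None
                                else logits_soft_cap)
```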

per_layer_sliding_window: Optional[int] = None,
tp_rank: Optional[int] = None,
prefix: str = "",
prev_attn: Optional[nn.Module] = None,
) -> None:
super().__init__()
if per_layer_sliding_window is not None:
@@ -96,7 +98,8 @@ def __init__(
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap)
blocksparse_params, logits_soft_cap,
tp_rank=tp_rank, prev_attn=prev_attn)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
1 change: 1 addition & 0 deletions vllm/attention/ops/hpu_paged_attn.py
@@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata:
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
alibi_blocks: Optional[torch.Tensor]


class HPUPagedAttention: