Resolved ALiBi bias regression due to porting flat PA #503
base: habana_main
@@ -125,6 +125,9 @@ def __init__(
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        max_seq_len: int = 4096,
        tp_rank: Optional[int] = None,
        prev_attn: Optional[torch.nn.Module] = None,
        **kwargs,
    ) -> None:
        super(AttentionImpl, self).__init__()
        self.kv_cache_dtype = kv_cache_dtype
@@ -142,11 +145,42 @@ def __init__(
            else ModuleFusedSDPA(HPUFusedSDPA)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        self.alibi_slopes = alibi_slopes
        self.alibi_slopes = None
        self.prompt_position_bias = None
        # Set upper bound on sequence length
        self.max_seq_len = int(
            os.getenv(
                'VLLM_PROMPT_ALIBI_MAX_SEQ_LEN',
                max_seq_len,
            ))
Review comment: Please change it to …
        # Set lower bound on sequence length
        self.max_seq_len = max([
            self.max_seq_len,
            int(os.getenv('VLLM_PROMPT_SEQ_BUCKET_MAX', '0')),
        ])
        self.tp_rank = tp_rank
        self.prev_attn = None if prev_attn is None else prev_attn.impl
        if alibi_slopes is not None:
            alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                               dtype=torch.bfloat16)
            self.alibi_slopes = alibi_slopes_tensor
            if (self.prev_attn is not None
                    and self.prev_attn.tp_rank == self.tp_rank):
                self.alibi_slopes = self.prev_attn.alibi_slopes
                self.prompt_position_bias = self.prev_attn.prompt_position_bias
            else:
                slope_tensor_dtype = {
                    True: torch.float32,
                    False: torch.bfloat16,
                }[os.getenv('VLLM_ALIBI_USE_FLOAT32_BIASES', '1').lower()
                  in ['1', 'true']]
                alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                                   dtype=slope_tensor_dtype)
                self.alibi_slopes = alibi_slopes_tensor
                # Creating the prompt_position_bias once and reusing it
                # if seq_len permits.
                self.prompt_position_bias = _make_prompt_alibi_bias(
                    alibi_slopes=self.alibi_slopes,
                    seq_len=self.max_seq_len,
                    dtype=self.alibi_slopes.dtype,
                )
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
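For reference, the sequence-length bound above resolves in two steps: VLLM_PROMPT_ALIBI_MAX_SEQ_LEN overrides the max_seq_len constructor default, and the result is then raised to at least VLLM_PROMPT_SEQ_BUCKET_MAX. A minimal standalone sketch of that resolution (the helper name is hypothetical and not part of this PR):

import os

# Hypothetical helper mirroring the bound resolution in the diff above:
# the env override takes precedence over the constructor default, and the
# prompt bucket max acts as a floor on the precomputed bias length.
def resolve_alibi_max_seq_len(max_seq_len: int = 4096) -> int:
    upper = int(os.getenv('VLLM_PROMPT_ALIBI_MAX_SEQ_LEN', max_seq_len))
    return max(upper, int(os.getenv('VLLM_PROMPT_SEQ_BUCKET_MAX', '0')))

# e.g. with VLLM_PROMPT_SEQ_BUCKET_MAX=8192 the precomputed prompt bias
# covers 8192 positions even if max_seq_len stays at its 4096 default.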
@@ -157,6 +191,12 @@ def __init__(
            assert alibi_slopes is None, \
                'Prefill with FusedSDPA not supported with alibi slopes!'

        self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
                                                'true').lower() == 'true'
        if not self.use_contiguous_pa:
            assert alibi_slopes is None, \
                'Non-contiguous PA not supported with alibi slopes!'

        suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
        if head_size not in suppored_head_sizes:
            raise ValueError(
@@ -230,27 +270,58 @@ def forward(
            query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
            kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                        self.head_size)

            if attn_metadata is None or attn_metadata.block_list is None:
                if not self.prefill_use_fusedsdpa:
                    # TODO: move this outside of model
                    assert attn_metadata.attn_bias is not None, \
                        'attn_bias must be set before calling model.forward'
                    # If we have alibi_slopes, incorporate them with
                    # position_bias and position_bias_offset.
                    attn_bias = attn_metadata.attn_bias
                    if self.alibi_slopes is not None:
                        position_bias = _make_alibi_bias(
                            self.alibi_slopes, self.num_kv_heads,
                            attn_bias.dtype, attn_bias.shape[-1])
                        attn_bias = attn_bias.tile(
                            (1, self.num_kv_heads, 1, 1))
                        attn_bias.add_(position_bias)
                    seq_lens_tensor = attn_metadata.seq_lens_tensor
                    position_bias = None
                    position_bias_offset = None
                    if (self.prompt_position_bias is not None
                            and self.alibi_slopes is not None):
                        if self.max_seq_len >= max(attn_bias.size(-2),
                                                   attn_bias.size(-1)):
                            # Using pre-computed prompt_position_bias subset.
                            position_bias = self.prompt_position_bias[:, :,
                                -attn_bias.size(-2):,
                                -attn_bias.size(-1):]
                        else:
                            # For longer sequences than precomputed,
                            # recreate the bias. This is memory inefficient.
                            position_bias = _make_prompt_alibi_bias(
                                alibi_slopes=self.alibi_slopes,
                                seq_len=max(attn_bias.size(-2),
                                            attn_bias.size(-1)),
                                dtype=self.alibi_slopes.dtype,
                            )
                        # If seq_lens_tensor is provided, we create a
                        # position_bias_offset. This offset helps handle
                        # sequences of varying lengths in a batch.
                        if seq_lens_tensor is not None:
                            position_bias_offset = seq_lens_tensor.unsqueeze(
                                1).tile(1, self.num_heads).to(
                                    dtype=self.alibi_slopes.dtype)
                            position_bias_offset.mul_(
                                self.alibi_slopes[None, :])
                            position_bias_offset = position_bias_offset \
                                - position_bias[:, :, -1, 0]
                else:
                    attn_bias = None
                    position_bias = None
                    position_bias_offset = None

                out = ops.prompt_attention(
                    query.view(query_shape),
                    key.view(kv_shape),
                    value.view(kv_shape),
                    attn_bias=attn_bias,
                    position_bias=position_bias,
                    position_bias_offset=position_bias_offset,
                    p=0.0,
                    scale=self.scale,
                    matmul_qk_op=self.matmul_qk,
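The slicing branch above works because the ALiBi bias depends only on relative positions, so the trailing window of a bias built for max_seq_len equals the bias built for a shorter prompt. A small self-contained check of that property, using a toy bias builder rather than the repo's helper:

import torch

def toy_alibi_bias(slopes: torch.Tensor, seq_len: int) -> torch.Tensor:
    # Relative-position bias in the spirit of the ALiBi paper: slope * (j - i).
    pos = torch.arange(seq_len, dtype=slopes.dtype)
    rel = pos[None, :] - pos[:, None]
    return slopes[:, None, None] * rel[None, :, :]

slopes = torch.tensor([0.5, 0.25])
full = toy_alibi_bias(slopes, 8)    # bias precomputed for a max length of 8
short = toy_alibi_bias(slopes, 3)   # bias for an actual prompt of length 3
# The trailing [-3:, -3:] corner of the large bias equals the small bias,
# which is why the forward pass can slice prompt_position_bias
# instead of rebuilding it for every shorter prompt.
assert torch.equal(full[:, -3:, -3:], short)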
@@ -278,6 +349,20 @@ def forward(
            output = out.reshape(batch_size, seq_len, hidden_size)
        else:
            # Decoding run.
            self.position_bias = None
            alibi_blocks = attn_metadata.alibi_blocks
            if self.alibi_slopes is not None and alibi_blocks is not None:
                if (self.prev_attn is not None
                        and self.prev_attn.tp_rank == self.tp_rank):
                    self.position_bias = self.prev_attn.position_bias
                else:
                    # For decoding, compute position bias using alibi_blocks.
                    self.position_bias = _make_decode_alibi_bias(
                        alibi_blocks=alibi_blocks,
                        alibi_slopes=self.alibi_slopes,
                        dtype=self.alibi_slopes.dtype,
                    )

            output = HPUPagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
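The reuse branch above is valid because the decode-stage bias depends only on the ALiBi slopes and the per-block position table, both of which are identical for every layer on a given TP rank. A hypothetical refactor of that branch into a helper, with names assumed from this diff and not part of the PR:

def _get_decode_position_bias(self, alibi_blocks):
    # Reuse the bias computed by an earlier layer on the same TP rank;
    # otherwise build it once from the per-block position offsets.
    if self.prev_attn is not None and self.prev_attn.tp_rank == self.tp_rank:
        return self.prev_attn.position_bias
    return _make_decode_alibi_bias(
        alibi_blocks=alibi_blocks,
        alibi_slopes=self.alibi_slopes,
        dtype=self.alibi_slopes.dtype,
    )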
@@ -288,14 +373,18 @@ def forward(
                block_scales=attn_metadata.block_scales,
                block_groups=attn_metadata.block_groups,
                scale=self.scale,
                position_bias=self.position_bias,
                matmul_qk_op=self.matmul_qk,
                matmul_av_op=self.matmul_av,
                batch2block_matmul_op=self.batch2block_matmul,
                block2batch_matmul_op=self.block2batch_matmul,
                keys_fetch_func=self.k_cache.fetch_from_cache,
                values_fetch_func=self.v_cache.fetch_from_cache)
                values_fetch_func=self.v_cache.fetch_from_cache,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
        output = output.view(batch_size, seq_len, hidden_size)
        return output

    def forward_encoder_decoder(
        self,
@@ -409,12 +498,25 @@ def forward_encoder_decoder(
        return output.view(batch_size, -1, hidden_size)


def _make_alibi_bias(
def _make_prompt_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    dtype: torch.dtype,
    seq_len: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Create the ALiBi position bias tensor for prompt stage.
    This tensor is reused or tiled as needed for each forward pass.
    Does not scale with batch size or number of blocks.

    Args:
        alibi_slopes: shape = [num_heads]
        seq_len: int
        dtype: torch.dtype

    Returns:
        A per-head bias tensor of shape [1, num_heads, seq_len, seq_len].
        This bias encodes positional information via ALiBi slopes.
    """
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(seq_len, 1)`
@@ -427,15 +529,54 @@ def _make_alibi_bias(

    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        1,  # batch size
    per_head_bias = torch.empty(
        1,
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    return bias
    )[:, :, :, :seq_len]
    # NOTE(Tanner):
    # .copy_ was not performing broadcasting of bias
    # to all 32 heads in Eager mode.
    per_head_bias[:, :] = bias
    per_head_bias.mul_(alibi_slopes[:, None, None])

    return per_head_bias


def _make_decode_alibi_bias(
    alibi_blocks: torch.Tensor,
    alibi_slopes: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Create the ALiBi position bias tensor for decode stage.
    Uses stored alibi_blocks and slopes for final scaling.
    Scales with number of blocks, not with batch size.

    Args:
        alibi_blocks: shape = [num_blocks, block_size]
        alibi_slopes: shape = [num_heads]
        dtype: torch.dtype

    Returns:
        A per-head bias tensor of shape [num_blocks, num_heads, block_size].
        Each row encodes position-dependent ALiBi slopes for decoding steps.
    """
    num_heads = alibi_slopes.shape[0]
    per_head_bias = torch.empty(
        alibi_blocks.size(0),
        num_heads,
        alibi_blocks.size(-1),
        device=alibi_slopes.device,
        dtype=dtype,
    )
    # NOTE(Tanner):
    # .copy_ was not performing broadcasting of bias
    # to all 32 heads in Eager mode.
    per_head_bias[:, :] = alibi_blocks.unsqueeze(-2)
    per_head_bias.mul_(alibi_slopes[None, :, None])

    return per_head_bias
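As a quick sanity check of the documented shapes, the decode bias is just the per-block position offsets broadcast across heads and scaled by each head's slope. A toy illustration with assumed values (dtype and HPU placement omitted):

import torch

alibi_blocks = torch.arange(8, dtype=torch.float32).view(2, 4)  # [num_blocks=2, block_size=4]
alibi_slopes = torch.tensor([1.0, 0.5, 0.25])                   # [num_heads=3]
per_head_bias = alibi_blocks.unsqueeze(-2) * alibi_slopes[None, :, None]
assert per_head_bias.shape == (2, 3, 4)  # [num_blocks, num_heads, block_size]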
@@ -38,9 +38,11 @@ def __init__(
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        logits_soft_cap: Optional[int] = 4096,
Review comment: That default value is already set for our implementation, so I'd rather not change it for others.
        per_layer_sliding_window: Optional[int] = None,
        tp_rank: Optional[int] = None,
        prefix: str = "",
        prev_attn: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        if per_layer_sliding_window is not None:
@@ -96,7 +98,8 @@ def __init__(
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap)
                             blocksparse_params, logits_soft_cap,
                             tp_rank=tp_rank, prev_attn=prev_attn)
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
Review comment: Do we need to modify this class' constructor here? This will be hard to upstream. TP rank can be obtained using:
from vllm.distributed import get_tensor_model_parallel_rank
and if I'm understanding this correctly, prev_attn is only used here to reuse the ALiBi bias, which can be generated in each layer separately. Alternatively, it can probably be cached.
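A minimal sketch of the reviewer's suggestion, assuming the impl derives the rank itself instead of receiving tp_rank through Attention.__init__ (the helper name is illustrative only, not part of this PR):

from vllm.distributed import get_tensor_model_parallel_rank

def _resolve_tp_rank(tp_rank=None):
    # Prefer an explicitly passed rank; otherwise ask the distributed state,
    # avoiding a new constructor argument threaded through every layer.
    return tp_rank if tp_rank is not None else get_tensor_model_parallel_rank()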