Skip to content

Commit

Permalink
vLLM-Base: Full enabling of ALiBi
Browse files Browse the repository at this point in the history
Changes:
- Added back alibi biases to decode stage.
- Optimized ALiBI memory usage.
  - Added environment variable "VLLM_PROMPT_ALIBI_MAX_SEQ_LEN" to allow
    large models to run with restricted prompt lengths.
  - Prompt biases instantiated once in __init__ rather than each
    forward.
  - Prompt and decode biases are shared across encoder/decoder layers.
- Added environment variable "VLLM_ALIBI_USE_FLOAT32_BIASES" to resolve
  accuracy issue on long sequences.
- Updated jais, mpt, falcon, baichuan, and bloom to work with ALiBI.
  - Due to bloom's 176B parameter size I was unable to test this model.
    Its changes are the simplest though.
- Works in lazy and eager mode.
- ALiBI is restricted to "VLLM_PROMPT_USE_FUSEDSDPA=false", and
  "VLLM_CONTIGUOUS_PA=true".
- Add position offsets to improve quality on BS > 1 with sequences of
  varying length.
- BS > 1 may have accuracy issues if on FW < 1.19.0. This is due to
  limitation in softmax. Resolved on FW >= 1.19.0.
- NTT patch for GQA

Co-authored-by: Tanner Voas <[email protected]>
Co-authored-by: Haihao Xiang <[email protected]>
Signed-off-by: Tanner Voas <[email protected]>
  • Loading branch information
tannervoas742 and xhaihao committed Dec 16, 2024
1 parent da61ecf commit 9fac2b5
Show file tree
Hide file tree
Showing 22 changed files with 545 additions and 171 deletions.
2 changes: 1 addition & 1 deletion requirements-hpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ pandas
tabulate
setuptools>=61
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@0766759
2 changes: 2 additions & 0 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ def __init__(
kv_cache_dtype: str = "auto",
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
tp_rank: Optional[int] = None,
prev_attn: Optional[torch.nn.Module] = None,
) -> None:
raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
assert blocksparse_params is not None
assert alibi_slopes is None, ValueError(
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
if blocksparse_params is not None:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
Expand Down
187 changes: 164 additions & 23 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
tp_rank: Optional[int] = None,
prev_attn: Optional[torch.nn.Module] = None,
**kwargs,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
Expand All @@ -142,11 +145,42 @@ def __init__(
else ModuleFusedSDPA(HPUFusedSDPA)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.alibi_slopes = alibi_slopes
self.alibi_slopes = None
self.prompt_position_bias = None
# Set upper bound on sequence length
self.max_seq_len = int(
os.getenv(
'VLLM_PROMPT_ALIBI_MAX_SEQ_LEN',
max_seq_len,
))
# Set lower bound on sequence length
self.max_seq_len = max([
self.max_seq_len,
int(os.getenv('VLLM_PROMPT_SEQ_BUCKET_MAX', '0')),
])
self.tp_rank = tp_rank
self.prev_attn = None if prev_attn is None else prev_attn.impl
if alibi_slopes is not None:
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16)
self.alibi_slopes = alibi_slopes_tensor
if (self.prev_attn is not None
and self.prev_attn.tp_rank == self.tp_rank):
self.alibi_slopes = self.prev_attn.alibi_slopes
self.prompt_position_bias = self.prev_attn.prompt_position_bias
else:
slope_tensor_dtype = {
True: torch.float32,
False: torch.bfloat16,
}[os.getenv('VLLM_ALIBI_USE_FLOAT32_BIASES', '1').lower()
in ['1', 'true']]
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=slope_tensor_dtype)
self.alibi_slopes = alibi_slopes_tensor
# Creating the prompt_position_bias once and reusing it
# if seq_len permits.
self.prompt_position_bias = _make_prompt_alibi_bias(
alibi_slopes=self.alibi_slopes,
seq_len=self.max_seq_len,
dtype=self.alibi_slopes.dtype,
)
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

Expand All @@ -157,6 +191,12 @@ def __init__(
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

self.use_contiguous_pa = os.environ.get('VLLM_CONTIGUOUS_PA',
'true').lower() == 'true'
if not self.use_contiguous_pa:
assert alibi_slopes is None, \
'Non-contiguous PA not supported with alibi slopes!'

suppored_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
Expand Down Expand Up @@ -230,27 +270,58 @@ def forward(
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)

if attn_metadata is None or attn_metadata.block_list is None:
if not self.prefill_use_fusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
# If we have alibi_slopes, incorporate them with
# position_bias and position_bias_offset.
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads,
attn_bias.dtype, attn_bias.shape[-1])
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
seq_lens_tensor = attn_metadata.seq_lens_tensor
position_bias = None
position_bias_offset = None
if (self.prompt_position_bias is not None
and self.alibi_slopes is not None):
if self.max_seq_len >= max(attn_bias.size(-2),
attn_bias.size(-1)):
# Using pre-computed prompt_position_bias subset.
position_bias = self.prompt_position_bias[:, :,
-attn_bias.size(-2):,
-attn_bias.size(-1):]
else:
# For longer sequences than precomputed,
# recreate the bias. This is memory inefficient.
position_bias = _make_prompt_alibi_bias(
alibi_slopes=self.alibi_slopes,
seq_len=max(attn_bias.size(-2),
attn_bias.size(-1)),
dtype=self.alibi_slopes.dtype,
)
# If seq_lens_tensor is provided, we create a
# position_bias_offset. This offset helps handle
# sequences of varying lengths in a batch.
if seq_lens_tensor is not None:
position_bias_offset = seq_lens_tensor.unsqueeze(
1).tile(1, self.num_heads).to(
dtype=self.alibi_slopes.dtype)
position_bias_offset.mul_(
self.alibi_slopes[None, :])
position_bias_offset = position_bias_offset \
- position_bias[:, :, -1, 0]
else:
attn_bias = None
position_bias = None
position_bias_offset = None

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
position_bias=position_bias,
position_bias_offset=position_bias_offset,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
Expand Down Expand Up @@ -278,6 +349,20 @@ def forward(
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
self.position_bias = None
alibi_blocks = attn_metadata.alibi_blocks
if self.alibi_slopes is not None and alibi_blocks is not None:
if (self.prev_attn is not None
and self.prev_attn.tp_rank == self.tp_rank):
self.position_bias = self.prev_attn.position_bias
else:
# For decoding stage, compute position bias using alibi_blocks.
self.position_bias = _make_decode_alibi_bias(
alibi_blocks=alibi_blocks,
alibi_slopes=self.alibi_slopes,
dtype=self.alibi_slopes.dtype,
)

output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
Expand All @@ -288,14 +373,18 @@ def forward(
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
position_bias=self.position_bias,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
values_fetch_func=self.v_cache.fetch_from_cache,
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
output = output.view(batch_size, seq_len, hidden_size)
return output

def forward_encoder_decoder(
self,
Expand Down Expand Up @@ -409,12 +498,25 @@ def forward_encoder_decoder(
return output.view(batch_size, -1, hidden_size)


def _make_alibi_bias(
def _make_prompt_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_len: int,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Create the ALiBi position bias tensor for prompt stage.
This tensor is reused or tiled as needed for each forward pass.
Does not scale with batch size or number of blocks.
Args:
alibi_slopes: shape = [num_heads]
seq_len: int
dtype: torch.dtype
Returns:
A per-head bias tensor of shape [1, num_heads, seq_len, seq_len].
This bias encodes positional information via ALiBi slopes.
"""
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
Expand All @@ -427,15 +529,54 @@ def _make_alibi_bias(

padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
per_head_bias = torch.empty(
1,
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return bias
)[:, :, :, :seq_len]
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias
# to all 32 heads in Eager mode.
per_head_bias[:, :] = bias
per_head_bias.mul_(alibi_slopes[:, None, None])

return per_head_bias


def _make_decode_alibi_bias(
alibi_blocks: torch.Tensor,
alibi_slopes: torch.Tensor,
dtype: torch.dtype,
) -> torch.Tensor:
"""
Create the ALiBi position bias tensor for decode stage.
Uses stored alibi_blocks and slopes for final scaling.
Scales with number of blocks, not with batch size.
Args:
alibi_blocks: shape = [num_blocks, block_size]
alibi_slopes: shape = [num_heads]
dtype: torch.dtype
Returns:
A per-head bias tensor of shape [num_blocks, num_heads, block_size].
Each row encodes position-dependent ALiBi slopes for decoding steps.
"""
num_heads = alibi_slopes.shape[0]
per_head_bias = torch.empty(
alibi_blocks.size(0),
num_heads,
alibi_blocks.size(-1),
device=alibi_slopes.device,
dtype=dtype,
)
# NOTE(Tanner):
# .copy_ was not performing broadcasting of bias
# to all 32 heads in Eager mode.
per_head_bias[:, :] = alibi_blocks.unsqueeze(-2)
per_head_bias.mul_(alibi_slopes[None, :, None])

return per_head_bias
1 change: 1 addition & 0 deletions vllm/attention/backends/ipex_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
if blocksparse_params is not None:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
if blocksparse_params is not None:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
if blocksparse_params is not None:
raise ValueError(
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
**kwargs,
) -> None:
if blocksparse_params is not None:
raise ValueError(
Expand Down
7 changes: 5 additions & 2 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
logits_soft_cap: Optional[int] = 4096,
per_layer_sliding_window: Optional[int] = None,
tp_rank: Optional[int] = None,
prefix: str = "",
prev_attn: Optional[nn.Module] = None,
) -> None:
super().__init__()
if per_layer_sliding_window is not None:
Expand Down Expand Up @@ -96,7 +98,8 @@ def __init__(
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap)
blocksparse_params, logits_soft_cap,
tp_rank=tp_rank, prev_attn=prev_attn)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
Expand Down
1 change: 1 addition & 0 deletions vllm/attention/ops/hpu_paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata:
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
alibi_blocks: Optional[torch.Tensor]


class HPUPagedAttention:
Expand Down
Loading

0 comments on commit 9fac2b5

Please sign in to comment.