[WIP] Add HPU support to vLLM v1 #487

Draft
wants to merge 36 commits into base: habana_main

Commits (36)
2191184
vLLM v1 HPU prototype
kzawora-intel Nov 12, 2024
fd77180
copy gpu model runner code, add hpugraphs support and profile run
kzawora-intel Nov 12, 2024
4dadef5
i am very much struggling
kzawora-intel Nov 13, 2024
9db1409
it's hopeless
kzawora-intel Nov 13, 2024
3b3098c
[wip] bypass prefill chunking in v1 scheduler
kzawora-intel Nov 14, 2024
c24adb5
colonoscopy
kzawora-intel Nov 14, 2024
2da069e
prefill runs, decode has deadlock, idk why
kzawora-intel Nov 14, 2024
932ce93
i'm done for today
kzawora-intel Nov 14, 2024
fc6a1c2
do better job at prefill chunking detection
kzawora-intel Nov 14, 2024
ff0ed54
mixed batch scheduling is still a problem
kzawora-intel Nov 15, 2024
50aa6b3
Merge remote-tracking branch 'origin/habana_main' into HEAD
kzawora-intel Nov 18, 2024
debec16
general hpu code rewrite
kzawora-intel Nov 18, 2024
0c1d0b6
add debug stuff, it seems like prefill is functional
kzawora-intel Nov 18, 2024
35d3e38
slight code cleanup
kzawora-intel Nov 18, 2024
491f991
remove garbage changes
kzawora-intel Nov 18, 2024
e29b84a
gsm8k now produces 69% acc on llama3.1
kzawora-intel Nov 19, 2024
27b4f32
add config not warmed up warnings
kzawora-intel Nov 19, 2024
087b5d2
add bucketinggit add -u .!
kzawora-intel Nov 19, 2024
6fdb6a9
llama3.1 now gives 81% in gsm8k without contiguous pa
kzawora-intel Nov 19, 2024
8714f9d
disable contiguous pa by default
kzawora-intel Nov 19, 2024
40ff0ac
async data copy
kzawora-intel Nov 19, 2024
28f2ac5
add split sampler optimization
kzawora-intel Nov 19, 2024
df7a1d4
add prompt batching
kzawora-intel Nov 20, 2024
623ed10
padded logits_indices and sampling + documentation
kzawora-intel Nov 20, 2024
0371c31
update docs
kzawora-intel Nov 20, 2024
d7b2a06
fix first-party random and greedy sampler for hpu
kzawora-intel Nov 20, 2024
c934e60
format.sh
kzawora-intel Nov 20, 2024
e0f4c26
add warmup w/ sampler (it doesn't work great tho)
kzawora-intel Nov 21, 2024
58c8f5d
add hpugraph check
kzawora-intel Nov 21, 2024
0c8b075
fix async engine, fix sampler corner cases
kzawora-intel Nov 22, 2024
fecedb5
Add padding-aware scheduling
kzawora-intel Nov 25, 2024
2ab1ac8
Merge remote-tracking branch 'origin/habana_main' into HEAD
kzawora-intel Nov 25, 2024
0d41073
bucketing refactor, enable contiguous pa, defrag blocks
kzawora-intel Nov 26, 2024
5645523
FreeKVCacheBlockHeapQueue bugfixes
kzawora-intel Nov 26, 2024
fd62723
[wip] add prefix caching support (it was actually really hard)
kzawora-intel Nov 26, 2024
e80f2be
fix hpugraphs and long seq corner case
kzawora-intel Dec 4, 2024
9 changes: 9 additions & 0 deletions vllm/attention/selector.py
@@ -24,6 +24,7 @@ class _Backend(enum.Enum):
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
HPU_ATTN_V1 = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()
@@ -174,6 +175,10 @@ def _cached_get_attn_backend(
logger.info("Using HPUAttention backend.")
from vllm.attention.backends.hpu_attn import HPUAttentionBackend
return HPUAttentionBackend
elif backend == _Backend.HPU_ATTN_V1:
logger.info("Using HPUAttentionV1 backend.")
from vllm.v1.attention.backends.hpu_attn import HPUAttentionBackendV1
return HPUAttentionBackendV1
elif backend == _Backend.PALLAS:
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
@@ -249,6 +254,10 @@ def which_attn_to_use(head_size: int,
return _Backend.ROCM_FLASH

if current_platform.is_hpu():
if selected_backend != _Backend.HPU_ATTN and selected_backend != _Backend.HPU_ATTN_V1:
logger.info("Cannot use %s backend on HPU.", selected_backend)
if use_v1:
return _Backend.HPU_ATTN_V1
return _Backend.HPU_ATTN

if use_v1:
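For context, a minimal sketch of the routing this hunk adds, written as a standalone helper (the name _pick_hpu_backend is hypothetical; the real logic lives inline in which_attn_to_use): on HPU, a v1 engine gets HPU_ATTN_V1 and everything else falls back to HPU_ATTN.

from vllm.attention.selector import _Backend
from vllm.logger import init_logger

logger = init_logger(__name__)


def _pick_hpu_backend(selected_backend: _Backend, use_v1: bool) -> _Backend:
    # Mirrors the branch above: warn when an unsupported backend was
    # requested on HPU, then route based on the engine version.
    if selected_backend not in (_Backend.HPU_ATTN, _Backend.HPU_ATTN_V1):
        logger.info("Cannot use %s backend on HPU.", selected_backend)
    return _Backend.HPU_ATTN_V1 if use_v1 else _Backend.HPU_ATTN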
1 change: 1 addition & 0 deletions vllm/distributed/parallel_state.py
@@ -984,6 +984,7 @@ def init_distributed_environment(
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
distributed_init_method, backend)
print(distributed_init_method)
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
361 changes: 361 additions & 0 deletions vllm/v1/attention/backends/hpu_attn.py
@@ -0,0 +1,361 @@
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################

import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
HPUPagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.utils import is_fake_hpu

logger = init_logger(__name__)


class HPUAttentionBackendV1(AttentionBackend):

@staticmethod
def get_name() -> str:
return "hpu-attn"

@staticmethod
def get_impl_cls() -> Type["HPUAttentionImpl"]:
return HPUAttentionImpl

@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return HPUAttentionMetadata

@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)

@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dsts: torch.Tensor,
) -> None:
HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts)

@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dsts: torch.Tensor,
) -> None:
HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts)


@dataclass
class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
"""Metadata for HPUAttentionbackend."""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
attn_bias: Optional[torch.Tensor]

# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
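# Worked example for the diagram above: with 16 tokens already in the KV
# cache (context_len=16) and 4 new tokens scheduled this step (query_len=4),
# seq_len = context_len + query_len = 20.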
seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]

@classmethod
def make_prefill_metadata(cls, seq_lens_tensor, num_prefills,
num_prefill_tokens, slot_mapping):
return cls(is_prompt=True,
block_list=None,
block_mapping=None,
block_usage=None,
block_indices=None,
block_offsets=None,
block_scales=None,
block_groups=None,
attn_bias=None,
num_decode_tokens=0,
context_lens_tensor=None,
multi_modal_placeholder_index_maps=None,
seq_lens_tensor=seq_lens_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
slot_mapping=slot_mapping)

@classmethod
def make_cached_prefill_metadata(cls, seq_lens_tensor, context_lens_tensor, num_prefills,
num_prefill_tokens, slot_mapping, block_list):
return cls(is_prompt=True,
block_list=block_list,
block_mapping=None,
block_usage=None,
block_indices=None,
block_offsets=None,
block_scales=None,
block_groups=None,
attn_bias=None,
num_decode_tokens=0,
context_lens_tensor=context_lens_tensor,
multi_modal_placeholder_index_maps=None,
seq_lens_tensor=seq_lens_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
slot_mapping=slot_mapping)

@classmethod
def make_decode_metadata(cls, block_list, block_usage, block_groups,
num_decode_tokens, slot_mapping):
return cls(is_prompt=False,
block_mapping=None,
block_indices=None,
block_offsets=None,
block_scales=None,
attn_bias=None,
seq_lens_tensor=None,
context_lens_tensor=None,
num_prefills=0,
num_prefill_tokens=0,
multi_modal_placeholder_index_maps=None,
block_list=block_list,
block_usage=block_usage,
block_groups=block_groups,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping)
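
# Illustrative use of the factories above (argument names only; the model
# runner supplies the real tensors):
#   prefill_meta = HPUAttentionMetadata.make_prefill_metadata(
#       seq_lens_tensor=seq_lens, num_prefills=num_seqs,
#       num_prefill_tokens=num_padded_tokens, slot_mapping=slot_mapping)
#   decode_meta = HPUAttentionMetadata.make_decode_metadata(
#       block_list=block_list, block_usage=block_usage,
#       block_groups=block_groups, num_decode_tokens=num_decode_seqs,
#       slot_mapping=slot_mapping)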


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

Generation tokens can contain padding when HPU graphs are used.
Currently, prompt tokens don't contain any padding.

The prompts might have different lengths, while the generation tokens
always have length 1.
"""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.matmul_qk = Matmul()
self.softmax = Softmax()
self.matmul_av = Matmul()
self.batch2block_matmul = Matmul()
self.block2batch_matmul = Matmul()
self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache()
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.alibi_slopes = alibi_slopes
if alibi_slopes is not None:
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16)
self.alibi_slopes = alibi_slopes_tensor
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
'1').lower() in ['1', 'true'] \
and not is_fake_hpu()
if self.prefill_use_fusedsdpa:
assert alibi_slopes is None, \
'Prefill with FusedSDPA not supported with alibi slopes!'

supported_head_sizes = HPUPagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.

Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl")
batch_size, seq_len, hidden_size = query.shape
_, seq_len_kv, _ = key.shape

query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
block_indices = attn_metadata.block_indices
block_offsets = attn_metadata.block_offsets
if attn_metadata.is_prompt:
key = key.unflatten(0, (block_indices.size(0), -1))
value = value.unflatten(0, (block_indices.size(0), -1))
if kv_cache is not None:
key_cache, value_cache = HPUPagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, block_indices,
block_offsets)
value_cache = self.v_cache(value, value_cache, block_indices,
block_offsets)

if attn_metadata.is_prompt:
# Prompt run.
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
self.head_size)
if attn_metadata is None or attn_metadata.block_list is None:
if not self.prefill_use_fusedsdpa:
# TODO: move this outside of model
assert attn_metadata.attn_bias is not None, \
'attn_bias must be set before calling model.forward'
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads,
attn_bias.dtype, attn_bias.shape[-1])
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
else:
attn_bias = None

out = ops.prompt_attention(
query.view(query_shape),
key.view(kv_shape),
value.view(kv_shape),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
)
else:
# TODO: enable FusedSDPA
out = HPUPagedAttention.forward_prefix(
query=query.view(query_shape),
key=key.view(kv_shape),
value=value.view(kv_shape),
key_cache=key_cache,
value_cache=value_cache,
block_list=attn_metadata.block_list,
attn_bias=attn_metadata.attn_bias,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
softmax_op=self.softmax,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_list=attn_metadata.block_list,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
batch2block_matmul_op=self.batch2block_matmul,
block2batch_matmul_op=self.block2batch_matmul,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)


def _make_alibi_bias(
alibi_slopes: torch.Tensor,
num_kv_heads: int,
dtype: torch.dtype,
seq_len: int,
) -> torch.Tensor:
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents ith element- jth
# element.
bias = bias[None, :] - bias[:, None]

padded_len = (seq_len + 7) // 8 * 8
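# e.g. seq_len = 100 -> padded_len = 104 (rounded up to a multiple of 8)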
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
1, # batch size
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return bias
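A minimal numeric sketch of what _make_alibi_bias computes, assuming two heads with slopes [1.0, 0.5] and seq_len=3 (illustrative values; it omits the leading batch dimension and the padding to a multiple of 8 applied above):

import torch

# Relative-position matrix: entry [i, j] = j - i, as built above.
rel = torch.arange(3)[None, :] - torch.arange(3)[:, None]
# tensor([[ 0,  1,  2],
#         [-1,  0,  1],
#         [-2, -1,  0]])

# Scale each head by its ALiBi slope, giving a [num_heads, seq_len, seq_len]
# bias that is added to the attention scores.
slopes = torch.tensor([1.0, 0.5])
bias = rel.to(torch.float32) * slopes[:, None, None]
# bias[0] == rel, bias[1] == 0.5 * rel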