Extend test_paged_attention to support HPU
itaraban committed Sep 10, 2024
1 parent 5cf8441 commit 60dfdff
Showing 1 changed file with 48 additions and 10 deletions.
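
As a quick way to exercise just the new HPU parametrization, the snippet below is a hypothetical local invocation (assuming an HPU-enabled vLLM build where vllm.hpu.ops and the Habana attention backend import cleanly); the "-k hpu" filter keeps only the test cases parametrized with the "hpu" device added in this commit.

import pytest

# Hypothetical local run of the modified test; paths and filters follow
# the file and parametrization shown in the diff below.
pytest.main([
    "-q",
    "tests/kernels/test_attention.py::test_paged_attention",
    "-k", "hpu",
])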
58 changes: 48 additions & 10 deletions tests/kernels/test_attention.py
@@ -3,18 +3,24 @@
 
 import pytest
 import torch
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from vllm import _custom_ops as ops
-from vllm.utils import get_max_shared_memory_bytes, is_hip
+from vllm.utils import get_max_shared_memory_bytes, is_hip, is_hpu
 
+if is_hpu():
+    import vllm.hpu.ops as hpu_ops
+    from vllm.attention.backends.habana_attn import _make_alibi_bias
+else:
+    from xformers import ops as xops
+    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
+
 from .allclose_default import get_default_atol, get_default_rtol
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
-MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
+MAX_SEQ_LEN = get_max_shared_memory_bytes(
+) // FLOAT32_BYTES - 512 if not is_hpu() else 128
 # There may not be enough gpu memory due to large NUM_BLOCKS.
 # Reduce NUM_BLOCKS when it happens.
 NUM_BLOCKS = 4321  # Arbitrary values for testing
@@ -35,9 +41,12 @@
 USE_ALIBI = [False, True]
 KV_CACHE_DTYPE = ["auto", "fp8"]
 SEEDS = [0]
-CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
-]
+if is_hpu():
+    DEVICES = ["hpu"]
+else:
+    DEVICES = [
+        f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    ]
 
 
 def ref_masked_attention(
@@ -120,7 +129,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
     kv_cache_factory,
     version: str,
@@ -134,12 +143,20 @@ def test_paged_attention(
     seed: int,
     device: str,
 ) -> None:
+    if is_hpu():
+        if version != "v1":
+            pytest.skip("Paged attention v2 not supported on HPU")
+        if kv_cache_dtype != "auto":
+            pytest.skip("Only auto kv_cache_dtype supported on HPU")
+
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
+    if is_hpu():
+        torch.hpu.manual_seed(seed)
     torch.set_default_device(device)
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
@@ -181,7 +198,26 @@ def test_paged_attention(
 
     # Call the paged attention kernel.
     output = torch.empty_like(query)
-    if version == "v1":
+    if device == "hpu":
+        key_cache_hpu = key_cache.permute((0, 3, 1, 2, 4)).flatten(3)
+        value_cache_hpu = value_cache.permute((0, 3, 1, 2))
+        position_bias = None
+        if alibi_slopes is not None:
+            position_bias = _make_alibi_bias(alibi_slopes, num_kv_heads,
+                                             alibi_slopes.dtype, max_seq_len)
+        output = hpu_ops.paged_attention_v1(
+            query,
+            key_cache_hpu,
+            value_cache_hpu,
+            num_kv_heads,
+            scale,
+            block_tables,
+            seq_lens,
+            block_size,
+            position_bias,
+            kv_cache_dtype,
+        )
+    elif version == "v1":
         ops.paged_attention_v1(
             output,
             query,
@@ -318,7 +354,7 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
@@ -328,6 +364,8 @@ def test_multi_query_kv_attention(
     seed: int,
     device: str,
 ) -> None:
+    if is_hpu():
+        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
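
A note on the new "hpu" branch above: both caches are re-laid out from the CUDA-style blocked format into a [num_blocks, block_size, num_heads, head_size] view before hpu_ops.paged_attention_v1 is called. Below is a minimal shape-only sketch of that permute/flatten step, assuming the usual cache shapes produced by kv_cache_factory; the concrete sizes are made up for illustration.

import torch

# Assumed CUDA-style cache layout (x is the vectorization factor):
#   key_cache:   [num_blocks, num_heads, head_size // x, block_size, x]
#   value_cache: [num_blocks, num_heads, head_size, block_size]
num_blocks, num_heads, head_size, block_size, x = 128, 8, 64, 16, 8
key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.zeros(num_blocks, num_heads, head_size, block_size)

# Same re-layout as the test's "hpu" path.
key_cache_hpu = key_cache.permute((0, 3, 1, 2, 4)).flatten(3)
value_cache_hpu = value_cache.permute((0, 3, 1, 2))

assert key_cache_hpu.shape == (num_blocks, block_size, num_heads, head_size)
assert value_cache_hpu.shape == (num_blocks, block_size, num_heads, head_size)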
