From 60dfdffecfa91d0c816f729504b13729c11678b5 Mon Sep 17 00:00:00 2001
From: Ilia Taraban
Date: Sun, 8 Sep 2024 22:29:28 +0200
Subject: [PATCH] Extend test_paged_attention to support HPU

---
 tests/kernels/test_attention.py | 58 +++++++++++++++++++++++++++------
 1 file changed, 48 insertions(+), 10 deletions(-)

diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index c7c6707461c3e..48b6700957205 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -3,18 +3,24 @@
 
 import pytest
 import torch
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from vllm import _custom_ops as ops
-from vllm.utils import get_max_shared_memory_bytes, is_hip
+from vllm.utils import get_max_shared_memory_bytes, is_hip, is_hpu
+
+if is_hpu():
+    import vllm.hpu.ops as hpu_ops
+    from vllm.attention.backends.habana_attn import _make_alibi_bias
+else:
+    from xformers import ops as xops
+    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from .allclose_default import get_default_atol, get_default_rtol
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
-MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
+MAX_SEQ_LEN = (128 if is_hpu() else
+               get_max_shared_memory_bytes() // FLOAT32_BYTES - 512)
 # There may not be enough gpu memory due to large NUM_BLOCKS.
 # Reduce NUM_BLOCKS when it happens.
 NUM_BLOCKS = 4321  # Arbitrary values for testing
@@ -35,9 +41,12 @@
 USE_ALIBI = [False, True]
 KV_CACHE_DTYPE = ["auto", "fp8"]
 SEEDS = [0]
-CUDA_DEVICES = [
-    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
-]
+if is_hpu():
+    DEVICES = ["hpu"]
+else:
+    DEVICES = [
+        f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+    ]
 
 
 def ref_masked_attention(
@@ -120,7 +129,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 def test_paged_attention(
     kv_cache_factory,
     version: str,
@@ -134,12 +143,20 @@ def test_paged_attention(
     seed: int,
     device: str,
 ) -> None:
+    if is_hpu():
+        if version != "v1":
+            pytest.skip("Paged attention v2 not supported on HPU")
+        if kv_cache_dtype != "auto":
+            pytest.skip("Only auto kv_cache_dtype supported on HPU")
+
     if kv_cache_dtype == "fp8" and head_size % 16:
         pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)
+    if is_hpu():
+        torch.hpu.manual_seed(seed)
     torch.set_default_device(device)
     scale = float(1.0 / (head_size**0.5))
     num_query_heads, num_kv_heads = num_heads
@@ -181,7 +198,26 @@ def test_paged_attention(
 
     # Call the paged attention kernel.
     output = torch.empty_like(query)
-    if version == "v1":
+    if device == "hpu":
+        key_cache_hpu = key_cache.permute((0, 3, 1, 2, 4)).flatten(3)
+        value_cache_hpu = value_cache.permute((0, 3, 1, 2))
+        position_bias = None
+        if alibi_slopes is not None:
+            position_bias = _make_alibi_bias(alibi_slopes, num_kv_heads,
+                                             alibi_slopes.dtype, max_seq_len)
+        output = hpu_ops.paged_attention_v1(
+            query,
+            key_cache_hpu,
+            value_cache_hpu,
+            num_kv_heads,
+            scale,
+            block_tables,
+            seq_lens,
+            block_size,
+            position_bias,
+            kv_cache_dtype,
+        )
+    elif version == "v1":
         ops.paged_attention_v1(
             output,
             query,
@@ -318,7 +354,7 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
@@ -328,6 +364,8 @@ def test_multi_query_kv_attention(
     seed: int,
     device: str,
 ) -> None:
+    if is_hpu():
+        pytest.skip("Not supported on HPU")
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
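Note on the HPU branch above: the test's kv_cache_factory produces key_cache as
(num_blocks, num_heads, head_size // x, block_size, x) and value_cache as
(num_blocks, num_heads, head_size, block_size), and the two permutes appear to
convert both caches into a (num_blocks, block_size, num_heads, head_size)
layout before calling hpu_ops.paged_attention_v1. The sketch below is a
standalone, CPU-only illustration of that shape conversion with dummy sizes;
the target layout is inferred from the permutes in the patch, not taken from
documented hpu_ops behavior.

    import torch

    # Dummy dimensions for illustration only.
    num_blocks, num_heads, head_size, block_size, x = 4, 8, 128, 16, 8

    # CUDA-style cache layouts as produced by the test's kv_cache_factory.
    key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x)
    value_cache = torch.randn(num_blocks, num_heads, head_size, block_size)

    # Same conversion as the HPU branch in the patch.
    key_cache_hpu = key_cache.permute((0, 3, 1, 2, 4)).flatten(3)
    value_cache_hpu = value_cache.permute((0, 3, 1, 2))

    # Both caches end up as (num_blocks, block_size, num_heads, head_size).
    assert key_cache_hpu.shape == (num_blocks, block_size, num_heads, head_size)
    assert value_cache_hpu.shape == (num_blocks, block_size, num_heads, head_size)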