From 07bb4ab16a55445c81a0980bcfcd8accd2ba2d4b Mon Sep 17 00:00:00 2001
From: Nir David
Date: Mon, 18 Nov 2024 13:47:45 +0200
Subject: [PATCH] Enable patching Fused SDPA

---
 vllm/attention/backends/hpu_attn.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 97410360e4c61..07aceaa0522ff 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -7,7 +7,8 @@
 
 import torch
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, Softmax, VLLMKVCache,
+                                      ModuleFusedSDPA)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
@@ -18,6 +19,14 @@
 
 logger = init_logger(__name__)
 
+HPUFusedSDPA = None
+try:
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+    HPUFusedSDPA = FusedSDPA
+except ImportError:
+    logger.warning("Could not import HPU FusedSDPA kernel. "
+                   "vLLM will use native implementation.")
+
 
 class HPUAttentionBackend(AttentionBackend):
 
@@ -115,6 +124,8 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
+        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None else ModuleFusedSDPA(
+            HPUFusedSDPA)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
         self.alibi_slopes = alibi_slopes
@@ -210,6 +221,7 @@ def forward(
                 softmax_op=self.softmax,
                 matmul_av_op=self.matmul_av,
                 valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                fsdpa_op=self.fused_scaled_dot_product_attention,
             )
         else:
             # TODO: enable FusedSDPA
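
Note (not part of the patch): below is a minimal, standalone sketch of the optional-kernel-with-fallback pattern the diff introduces, written against plain PyTorch. my_accel_lib, fused_attention, and TinyAttention are hypothetical stand-ins for illustration only, not vLLM or Habana APIs; the real patch imports FusedSDPA from habana_frameworks, wraps it in ModuleFusedSDPA, and forwards it as fsdpa_op.

import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)

# Try to load an optional fused kernel; keep None as the fallback marker
# (the patch does the same with habana_frameworks' FusedSDPA).
fused_attention = None
try:
    from my_accel_lib import fused_attention  # hypothetical optional kernel
except ImportError:
    logger.warning("Fused attention kernel unavailable; "
                   "using the native implementation.")


class TinyAttention(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Mirror the patch: store None when the import failed so the
        # forward pass can dispatch on it.
        self.fsdpa_op = fused_attention

    def forward(self, q, k, v):
        if self.fsdpa_op is not None:
            return self.fsdpa_op(q, k, v)
        # Native PyTorch path used whenever the kernel is missing.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


# Runs everywhere; the fused path is taken only when the kernel imports.
q = k = v = torch.randn(1, 4, 16, 64)
out = TinyAttention()(q, k, v)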