diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 97410360e4c61..248aafaf1266c 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -7,7 +7,8 @@
 
 import torch
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, Softmax, VLLMKVCache,
+                                      ModuleFusedSDPA)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
@@ -18,6 +19,14 @@
 
 logger = init_logger(__name__)
 
+HPUFusedSDPA = None
+try:
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+    HPUFusedSDPA = FusedSDPA
+except ImportError:
+    logger.warning("Could not import HPU FusedSDPA kernel. "
+                   "vLLM will use native implementation.")
+
 
 class HPUAttentionBackend(AttentionBackend):
 
@@ -115,6 +124,8 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
+        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
+            else ModuleFusedSDPA(HPUFusedSDPA)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
         self.alibi_slopes = alibi_slopes
@@ -210,6 +221,7 @@ def forward(
                     softmax_op=self.softmax,
                     matmul_av_op=self.matmul_av,
                     valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                    fsdpa_op=self.fused_scaled_dot_product_attention,
                 )
            else:
                # TODO: enable FusedSDPA
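For reference, below is a minimal standalone sketch of the optional-kernel pattern this patch applies: import the HPU fused SDPA kernel when it is available, otherwise keep a `None` sentinel and fall back to a native matmul/softmax path at call time. The `prompt_attention` helper, its `fsdpa_op` parameter, and the argument ordering assumed for the fused call are illustrative stand-ins, not the actual `vllm_hpu_extension` code touched by this diff.

```python
import torch
import torch.nn.functional as F

# Optional import: the Habana package only exists on HPU hosts.
try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    FusedSDPA = None


def prompt_attention(query, key, value, scale, fsdpa_op=None):
    """Use the injected fused op when present, else compute attention natively."""
    if fsdpa_op is not None:
        # Assumed argument ordering for the fused kernel:
        # (q, k, v, attn_mask, dropout_p, is_causal, scale); verify against
        # the installed habana_frameworks / vllm_hpu_extension versions.
        return fsdpa_op(query, key, value, None, 0.0, True, scale)
    # Native fallback: plain scaled matmul + softmax attention.
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    return torch.matmul(F.softmax(scores, dim=-1), value)
```

On a non-HPU host the fallback branch runs with plain PyTorch; on HPU, a wrapper such as `ModuleFusedSDPA(FusedSDPA)` would be passed as `fsdpa_op`, mirroring how the diff wires `self.fused_scaled_dot_product_attention` into the `fsdpa_op` keyword of `ops.prompt_attention`.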