From dc986f6a8935be9fce91f4cc9e6c288c6d9dab01 Mon Sep 17 00:00:00 2001
From: Nir David
Date: Mon, 18 Nov 2024 13:47:45 +0200
Subject: [PATCH] Enable patching Fused SDPA

---
 requirements-hpu.txt                |  2 +-
 vllm/attention/backends/hpu_attn.py | 14 +++++++++++++-
 vllm/executor/hpu_executor.py       |  3 ++-
 3 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/requirements-hpu.txt b/requirements-hpu.txt
index 3586e70406d72..eb6dc55c5be48 100644
--- a/requirements-hpu.txt
+++ b/requirements-hpu.txt
@@ -8,5 +8,5 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@070591a
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@41ff369
 neural-compressor @ git+https://github.com/intel/neural-compressor.git@b196432
diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index dfdddbb67d122..5365507c6eb5f 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -8,7 +8,8 @@
 
 import torch
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
+                                      VLLMKVCache)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
@@ -20,6 +21,14 @@
 
 logger = init_logger(__name__)
 
+HPUFusedSDPA = None
+try:
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+    HPUFusedSDPA = FusedSDPA
+except ImportError:
+    logger.warning("Could not import HPU FusedSDPA kernel. "
+                   "vLLM will use native implementation.")
+
 
 class HPUAttentionBackend(AttentionBackend):
 
@@ -117,6 +126,8 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
+        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
+            else ModuleFusedSDPA(HPUFusedSDPA)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
         self.alibi_slopes = alibi_slopes
@@ -222,6 +233,7 @@ def forward(
                     softmax_op=self.softmax,
                     matmul_av_op=self.matmul_av,
                     valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                    fsdpa_op=self.fused_scaled_dot_product_attention,
                 )
             else:
                 # TODO: enable FusedSDPA
diff --git a/vllm/executor/hpu_executor.py b/vllm/executor/hpu_executor.py
index e82cc10d0e9f0..27bb3f9d33bd8 100644
--- a/vllm/executor/hpu_executor.py
+++ b/vllm/executor/hpu_executor.py
@@ -201,7 +201,8 @@ def stop_profile(self) -> None:
         self.driver_worker.stop_profile()
 
     def shutdown(self) -> None:
-        if hasattr(self.driver_worker, 'shutdown_inc'):
+        if hasattr(self, "driver_worker") and hasattr(self.driver_worker,
+                                                      'shutdown_inc'):
             self.driver_worker.shutdown_inc()
 
 
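
For reference, a minimal standalone sketch of the guarded-import pattern that the hpu_attn.py hunks above introduce: try to import the optional HPU kernel, keep a module-level None fallback, and only build the wrapper module when the kernel is present, so downstream callers can treat None as "use the native path". FusedSDPAWrapperSketch and fused_sdpa_op below are hypothetical names, not part of this patch or of vLLM; the wrapper is a stand-in for ModuleFusedSDPA from vllm_hpu_extension.utils, and the .apply() call reflects the assumption that FusedSDPA is exposed autograd-Function style.

# Illustrative sketch only, not part of the patch: the optional-import and
# conditional-wrapper pattern used in vllm/attention/backends/hpu_attn.py.
import logging

import torch.nn as nn

logger = logging.getLogger(__name__)

HPUFusedSDPA = None
try:
    # Present only on hosts with the Habana PyTorch bridge installed.
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
    HPUFusedSDPA = FusedSDPA
except ImportError:
    logger.warning("Could not import HPU FusedSDPA kernel. "
                   "The native attention path will be used instead.")


class FusedSDPAWrapperSketch(nn.Module):
    """Hypothetical stand-in for ModuleFusedSDPA: wraps the fused kernel."""

    def __init__(self, fsdpa_cls):
        super().__init__()
        self._fsdpa_cls = fsdpa_cls

    def forward(self, *args, **kwargs):
        # Assumes the HPU kernel is invoked via an .apply() entry point.
        return self._fsdpa_cls.apply(*args, **kwargs)


# Mirrors the __init__ wiring in the patch: keep None when the kernel is
# unavailable so callers (e.g. the fsdpa_op=... argument) fall back cleanly.
fused_sdpa_op = (None if HPUFusedSDPA is None
                 else FusedSDPAWrapperSketch(HPUFusedSDPA))

Wrapping the kernel in an nn.Module rather than calling it directly is what makes it patchable, which matches the intent of the subject line: quantization and instrumentation tooling can replace or hook the module instance without touching the attention code itself.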