Enable patching Fused SDPA
nirda7 committed Dec 1, 2024
1 parent 49c9efa commit e69c6cf
Showing 1 changed file with 13 additions and 1 deletion.
vllm/attention/backends/hpu_attn.py
@@ -7,7 +7,8 @@
 
 import torch
 import vllm_hpu_extension.ops as ops
-from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
+from vllm_hpu_extension.utils import (Matmul, Softmax, VLLMKVCache,
+                                      ModuleFusedSDPA)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
@@ -18,6 +19,14 @@
 
 logger = init_logger(__name__)
 
+HPUFusedSDPA = None
+try:
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+    HPUFusedSDPA = FusedSDPA
+except ImportError:
+    logger.warning("Could not import HPU FusedSDPA kernel. "
+                   "vLLM will use native implementation.")
+
 
 class HPUAttentionBackend(AttentionBackend):
 
@@ -115,6 +124,8 @@ def __init__(
         self.block2batch_matmul = Matmul()
         self.k_cache = VLLMKVCache()
         self.v_cache = VLLMKVCache()
+        self.fused_scaled_dot_product_attention = None if HPUFusedSDPA is None \
+            else ModuleFusedSDPA(HPUFusedSDPA)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
         self.sliding_window = sliding_window
         self.alibi_slopes = alibi_slopes
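
The conditional above is presumably what the commit title means by "patching": instead of calling the raw FusedSDPA kernel directly, the backend now holds it behind a torch.nn.Module wrapper (ModuleFusedSDPA), alongside the existing Matmul, Softmax, and VLLMKVCache wrappers, so the op appears in the module tree and can be swapped or instrumented by patching tools. As a rough, assumed approximation only (the real wrapper lives in vllm_hpu_extension.utils and its argument list may differ):

import torch


class PatchableFusedSDPA(torch.nn.Module):
    """Illustrative stand-in for ModuleFusedSDPA: a thin nn.Module around a
    fused scaled-dot-product-attention kernel."""

    def __init__(self, fsdpa_kernel):
        super().__init__()
        self._kernel = fsdpa_kernel

    def forward(self, query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=True, scale=None):
        # Assumed calling convention; the HPU kernel is an autograd Function
        # invoked through its .apply() entry point.
        return self._kernel.apply(query, key, value, attn_mask, dropout_p,
                                  is_causal, scale)
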
@@ -210,6 +221,7 @@ def forward(
                 softmax_op=self.softmax,
                 matmul_av_op=self.matmul_av,
                 valid_seq_lengths=attn_metadata.seq_lens_tensor,
+                fsdpa_op=self.fused_scaled_dot_product_attention,
             )
         else:
             # TODO: enable FusedSDPA
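
Passing the wrapper through as fsdpa_op leaves the choice to the extension's prompt_attention: with fsdpa_op=None (the import-failure case) the existing matmul/softmax path should still be used, otherwise the fused kernel can take over. A simplified, assumed sketch of that dispatch, reusing the PatchableFusedSDPA signature above; this is not the actual code in vllm_hpu_extension.ops:

import torch


def prompt_attention_sketch(query, key, value, scale, matmul_qk_op,
                            softmax_op, matmul_av_op, fsdpa_op=None):
    # Assumed behaviour: prefer the injected fused-SDPA module when present.
    if fsdpa_op is not None:
        return fsdpa_op(query, key, value, None, 0.0, True, scale)
    # Native fallback: explicit QK^T -> softmax -> AV via the wrapped ops.
    attn = matmul_qk_op(query, key.transpose(-2, -1)) * scale
    attn = softmax_op(attn, dim=-1)
    return matmul_av_op(attn, value)


# Example with plain torch callables standing in for the Matmul/Softmax wrappers:
# out = prompt_attention_sketch(q, k, v, scale=0.125, matmul_qk_op=torch.matmul,
#                               softmax_op=torch.softmax, matmul_av_op=torch.matmul)
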