Fix: selecting correct backend for MultiHeadAttention #645

Status: Open. Wants to merge 9 commits into base: habana_main
35 changes: 31 additions & 4 deletions vllm/attention/layer.py
@@ -191,11 +191,16 @@ def __init__(
             kv_cache_dtype=None,
             block_size=16,
             is_attention_free=False)
-        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            attn_backend = _Backend.XFORMERS
-
-        self.attn_backend = attn_backend if attn_backend in {
-            _Backend.TORCH_SDPA, _Backend.XFORMERS
+        attn_backend_enum = backend_name_to_enum(attn_backend.get_name())
+
+        if attn_backend_enum in {
+                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
+        }:
+            attn_backend_enum = _Backend.XFORMERS
+
+        self.attn_backend = attn_backend_enum if attn_backend_enum in {
+            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
         } else _Backend.TORCH_SDPA
 
     def forward(
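
Note on the selection logic above: the backend object returned by `get_attn_backend` is first mapped to its `_Backend` enum member via `backend_name_to_enum(attn_backend.get_name())` before the allow-list check, and anything outside `{TORCH_SDPA, XFORMERS, HPU_ATTN}` falls back to `TORCH_SDPA`. Below is a minimal, self-contained sketch of that fallback behaviour; `_Backend`, `backend_name_to_enum`, and the backend class are simplified stand-ins, not the real vLLM definitions.

```python
# Minimal sketch of the fallback selection above (simplified stand-ins for
# vLLM's _Backend, backend_name_to_enum, and the backend class; not the real APIs).
from enum import Enum, auto


class _Backend(Enum):
    FLASH_ATTN = auto()
    FLASH_ATTN_VLLM_V1 = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()
    HPU_ATTN = auto()


def backend_name_to_enum(name: str) -> _Backend:
    # Stand-in: map the backend's reported name onto the enum.
    return _Backend[name]


class FakeHPUBackend:
    """Stand-in for the backend object returned by get_attn_backend."""

    @staticmethod
    def get_name() -> str:
        return "HPU_ATTN"


def select_mha_backend(attn_backend) -> _Backend:
    attn_backend_enum = backend_name_to_enum(attn_backend.get_name())
    # FLASH_ATTN variants are remapped to XFORMERS for this layer.
    if attn_backend_enum in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
        attn_backend_enum = _Backend.XFORMERS
    # Only these three are used directly; everything else falls back to TORCH_SDPA.
    return attn_backend_enum if attn_backend_enum in {
        _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
    } else _Backend.TORCH_SDPA


assert select_mha_backend(FakeHPUBackend()) is _Backend.HPU_ATTN
```
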
@@ -228,6 +233,28 @@ def forward(
                                                  value,
                                                  scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.HPU_ATTN:
+            from habana_frameworks.torch.hpex.kernels import FusedSDPA
+            from vllm_hpu_extension.utils import ModuleFusedSDPA
+
+            fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+
+            query, key, value = (x.transpose(1, 2)
+                                 for x in (query, key, value))
+
+            out = fsdpa_op(query,
+                           key,
+                           value,
+                           None,
+                           dropout_p=0.0,
+                           is_causal=True,
+                           scale=self.scale,
+                           softmax_mode="fast",
+                           recompute_mode=True,
+                           valid_sequence_lengths=None)
+
+            out = out.transpose(1, 2).contiguous()
+
         return out.view(bsz, q_len, -1)
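
The new `HPU_ATTN` branch keeps the tensor-layout contract of the existing `TORCH_SDPA` branch: transpose to put heads before the sequence dimension, run the fused kernel, transpose back, then flatten with `view(bsz, q_len, -1)`. Below is a small CPU-only sketch of that layout, using PyTorch's `scaled_dot_product_attention` in place of `ModuleFusedSDPA`; the concrete shapes are assumptions inferred from the transposes in the diff, not taken verbatim from vLLM.

```python
# CPU-only sketch of the layout shared by the TORCH_SDPA and HPU_ATTN branches.
# Shapes are assumptions inferred from the transposes above.
import torch
import torch.nn.functional as F

bsz, q_len, num_heads, head_dim = 2, 8, 4, 32
scale = head_dim ** -0.5

# Layer inputs: (bsz, seq_len, num_heads, head_dim).
query = torch.randn(bsz, q_len, num_heads, head_dim)
key = torch.randn(bsz, q_len, num_heads, head_dim)
value = torch.randn(bsz, q_len, num_heads, head_dim)

# Both branches move heads in front of the sequence dimension for the kernel.
q, k, v = (x.transpose(1, 2) for x in (query, key, value))

# TORCH_SDPA path; the HPU branch calls ModuleFusedSDPA(FusedSDPA) with the
# same layout and then undoes the transpose the same way.
out = F.scaled_dot_product_attention(q, k, v, scale=scale)

# Back to (bsz, q_len, num_heads, head_dim), then flatten the head dimensions.
out = out.transpose(1, 2).contiguous()
print(out.view(bsz, q_len, -1).shape)  # torch.Size([2, 8, 128])
```
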

