diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py
index 7a867e79b203d..2259630fa10b7 100644
--- a/vllm/attention/backends/habana_attn.py
+++ b/vllm/attention/backends/habana_attn.py
@@ -2,6 +2,7 @@
 # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
 ###############################################################################
 
+import os
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Type
 
@@ -166,6 +167,12 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
+        self.prefill_usefusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
+                                              '0').lower() in ['1', 'true']
+        if self.prefill_usefusedsdpa:
+            assert alibi_slopes is None, \
+                'Prefill with FusedSDPA not supported with alibi slopes!'
+
         suppored_head_sizes = HabanaPagedAttention.get_supported_head_sizes()
         if head_size not in suppored_head_sizes:
             raise ValueError(
@@ -223,15 +230,18 @@ def forward(
         if attn_metadata.is_prompt:
             # Prompt run.
             if kv_cache is None or attn_metadata.block_tables.numel() == 0:
-                # TODO: move this outside of model
-                assert attn_metadata.attn_bias is not None, \
-                    'attn_bias must be set before calling model.forward!'
-                attn_bias = attn_metadata.attn_bias
-                if self.alibi_slopes is not None and \
-                   self.position_bias is not None:
-                    attn_bias.add_(self.position_bias[:, :,
-                                                      -attn_bias.size(2):,
-                                                      -attn_bias.size(3):])
+                if not self.prefill_usefusedsdpa:
+                    # TODO: move this outside of model
+                    assert attn_metadata.attn_bias is not None, \
+                        'attn_bias must be set before calling model.forward!'
+                    attn_bias = attn_metadata.attn_bias
+                    if self.alibi_slopes is not None and \
+                       self.position_bias is not None:
+                        attn_bias.add_(self.position_bias[:, :,
+                                                          -attn_bias.size(2):,
+                                                          -attn_bias.size(3):])
+                else:
+                    attn_bias = None
 
                 query_shape = (batch_size, seq_len, self.num_heads,
                                self.head_size)
@@ -247,6 +257,7 @@ def forward(
                     matmul_qk_op=self.matmul_qk,
                     softmax_op=self.softmax,
                     matmul_av_op=self.matmul_av,
+                    valid_seq_lengths=attn_metadata.seq_lens_tensor,
                 )
                 output = out.reshape(batch_size, seq_len, hidden_size)
             else:
diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index 23f6964723d3f..2af5634a8d1a6 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -21,6 +21,13 @@
 except ImportError:
     logger.warning("Could not import HPU FusedRMSNorm kernel. "
                    "vLLM will use forward_native implementation of RMSNorm.")
+HPUFusedSDPA = None
+try:
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+    HPUFusedSDPA = FusedSDPA
+except ImportError:
+    logger.warning("Could not import HPU FusedSDPA kernel. "
+                   "vLLM will use native implementation.")
 
 PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')
 
@@ -126,6 +133,21 @@ def static_fused_moe(hidden_states, w1, w2, score, topk):
     return final_hidden_states.view(-1, D)
 
+#TODO: remove after fusedsdpa fix for query_head != kv_head
+def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
+    The kv go from (batch, num_key_value_heads, seqlen, head_dim) to
+    (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = kv.shape
+    if n_rep == 1:
+        return kv
+    kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
+                                     head_dim)
+    return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 def prompt_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -136,24 +158,36 @@ def prompt_attention(
     matmul_qk_op=torch.matmul,
     softmax_op=torch.softmax,
     matmul_av_op=torch.matmul,
+    valid_seq_lengths: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)
     query_heads = query.size(1)
     kv_heads = key.size(1)
-    if query_heads != kv_heads:
-        query = query.unflatten(1, (kv_heads, -1))
-        key = key.unflatten(1, (kv_heads, 1))
-        value = value.unflatten(1, (kv_heads, 1))
+    if attn_bias is not None or HPUFusedSDPA is None:
+        if query_heads != kv_heads:
+            query = query.unflatten(1, (kv_heads, -1))
+            key = key.unflatten(1, (kv_heads, 1))
+            value = value.unflatten(1, (kv_heads, 1))
+            if attn_bias is not None:
+                attn_bias = attn_bias.unsqueeze(2)
+        attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
         if attn_bias is not None:
-            attn_bias = attn_bias.unsqueeze(2)
-    attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
-    if attn_bias is not None:
-        attn_weights.add_(attn_bias)
-    attn_weights = softmax_op(attn_weights, dim=-1)
-    attn_weights = matmul_av_op(attn_weights, value)
-    if query_heads != kv_heads:
-        attn_weights = attn_weights.flatten(1, 2)
+            attn_weights.add_(attn_bias)
+        attn_weights = softmax_op(attn_weights, dim=-1)
+        attn_weights = matmul_av_op(attn_weights, value)
+        if query_heads != kv_heads:
+            attn_weights = attn_weights.flatten(1, 2)
+    else:
+        #TODO: remove after fusedsdpa fix for query_heads != kv_heads
+        if query_heads != kv_heads:
+            key = repeat_kv(key, int(query_heads // kv_heads))
+            value = repeat_kv(value, int(query_heads // kv_heads))
+        softmax_mode = 'fast'
+        recompute_mode = True
+        attn_weights = FusedSDPA.apply(query, key, value, None, 0.0, True,
+                                       scale, softmax_mode, recompute_mode,
+                                       valid_seq_lengths, 'right')
     attn_weights = attn_weights.transpose(1, 2)
     return attn_weights
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 72aba42ae8553..e52b61539b540 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -151,6 +151,9 @@ class HpuModelAdapter():
 
     def __init__(self, model, enforce_eager):
         self.model = model
+        self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
+                                               '0').lower() in ['1', 'true']
+
         if not htorch.utils.internal.is_lazy() and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
@@ -159,7 +162,7 @@ def __init__(self, model, enforce_eager):
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                        dtype):
         prefill_metadata = attn_metadata
-        if prefill_metadata is None:
+        if prefill_metadata is None or self.prefill_use_fusedsdpa:
             return attn_metadata
 
         seq_lens_t = prefill_metadata.seq_lens_tensor
@@ -599,7 +602,6 @@ def _prepare_prompt(
             # actual prompt lens
             context_lens.append(context_len)
             query_lens.append(seq_len - context_len)
-
             input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
@@ -672,7 +674,6 @@ def _prepare_prompt(
         max_prompt_len = max(
             find_bucket(max(seq_lens), self.prompt_seq_bucket_cfg),
             self.block_size)
-
         input_tokens = make_tensor_with_pad(input_tokens,
                                             max_len=max_prompt_len,
                                             pad=0,
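
The new behavior is opt-in via an environment variable. A minimal standalone sketch (not part of the patch; the helper name is illustrative) of the truthiness rule shared by habana_attn.py and habana_model_runner.py: only '1' or 'true', case-insensitively, enables the fused prompt path; unset, '0', or anything else keeps the existing bias-based path.

import os


def prompt_uses_fusedsdpa() -> bool:
    # Mirrors the check added in both files: '1' and 'true' (any case) enable
    # the FusedSDPA prompt path; everything else keeps the default behavior.
    return os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', '0').lower() in ['1', 'true']


os.environ['VLLM_PROMPT_USE_FUSEDSDPA'] = 'True'
assert prompt_uses_fusedsdpa()
os.environ['VLLM_PROMPT_USE_FUSEDSDPA'] = '0'
assert not prompt_uses_fusedsdpa()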
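
The repeat_kv helper added to vllm/hpu/ops.py states equivalence with torch.repeat_interleave on the head dimension; a quick CPU-only sanity check of that claim (standalone sketch, not part of the patch; shapes are arbitrary):

import torch


def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the patch: expand a (batch, kv_heads, seqlen, head_dim)
    # tensor to (batch, kv_heads * n_rep, seqlen, head_dim) so that FusedSDPA
    # can be fed tensors with query_heads == kv_heads.
    batch, num_key_value_heads, slen, head_dim = kv.shape
    if n_rep == 1:
        return kv
    kv = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen,
                                     head_dim)
    return kv.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


kv = torch.randn(2, 4, 16, 64)   # (batch, kv_heads, seqlen, head_dim)
n_rep = 8                        # e.g. 32 query heads over 4 KV heads
assert torch.equal(repeat_kv(kv, n_rep),
                   torch.repeat_interleave(kv, n_rep, dim=1))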
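
For intuition on why prompt_attention can skip building attn_bias when the flag is set: without alibi slopes the prompt mask is purely causal, so it can be requested from the kernel itself instead of being materialized as a tensor. A standalone CPU sketch (not part of the patch), using torch.nn.functional.scaled_dot_product_attention from PyTorch 2.1+ as a stand-in for Habana's FusedSDPA; the real kernel additionally takes softmax mode, recompute mode, valid_seq_lengths, and a padding side, which this sketch does not model.

import torch
import torch.nn.functional as F

batch, heads, seqlen, head_dim = 2, 8, 16, 64
scale = head_dim ** -0.5
q, k, v = (torch.randn(batch, heads, seqlen, head_dim) for _ in range(3))

# Unfused path: an explicit lower-triangular bias is added before the softmax,
# which is what _set_attn_bias prepares when the flag is off.
causal_bias = torch.full((seqlen, seqlen), float('-inf')).triu(1)
weights = torch.softmax(q @ k.transpose(-1, -2) * scale + causal_bias, dim=-1)
reference = weights @ v

# Fused path: no bias tensor is materialized; causal masking is requested from
# the kernel (is_causal=True), as the FusedSDPA.apply call does on HPU.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)

assert torch.allclose(reference, fused, atol=1e-5)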