From a2d300769fef5b1927132e3cb733ebaf0d218e59 Mon Sep 17 00:00:00 2001
From: "Chendi.Xue"
Date: Fri, 10 Jan 2025 20:54:15 +0000
Subject: [PATCH] Fix fusedSDPA fp8 70B issue

Signed-off-by: Chendi.Xue
---
 vllm/attention/backends/hpu_attn.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 1d5b83c1e61f2..e6c2925beb1fe 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -49,14 +49,12 @@ def prompt_fsdpa(
     query_heads = query.size(1)
     kv_heads = key.size(1)
     VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get(
-        'VLLM_REMOVE_REPEAT_KV_CACHE_MERGED_PREFILL', '1') == '1'
+        'VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1'
     # TODO: remove after fusedsdpa fix for query_heads != kv_heads
     if query_heads != kv_heads:
         if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE:
             key = ops.repeat_kv(key, int(query_heads // kv_heads))
             value = ops.repeat_kv(value, int(query_heads // kv_heads))
-        if attn_bias is not None:
-            attn_bias = attn_bias.unsqueeze(1)
         softmax_mode = 'fast'
         recompute_mode = True
         attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False,
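
Note (illustrative, not part of the patch): the hunk above expands the KV heads
so that fusedSDPA sees query_heads == kv_heads, which is the grouped-query
attention workaround the TODO refers to. Below is a minimal sketch of what a
repeat_kv along the head dimension can look like, assuming a
(batch, num_kv_heads, seq_len, head_dim) tensor layout (consistent with
key.size(1) == kv_heads in the hunk); vllm's actual ops.repeat_kv may be
implemented differently.

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Repeat each KV head n_rep times so the KV head count matches the
        # query head count before the fused attention call. Illustrative
        # sketch only; assumes (batch, num_kv_heads, seq_len, head_dim).
        if n_rep == 1:
            return hidden_states
        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        # Insert a repeat axis, broadcast it, then fold it into the head dim.
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, seq_len, head_dim)
        return hidden_states.reshape(batch, num_kv_heads * n_rep,
                                     seq_len, head_dim)

With query_heads = 64 and kv_heads = 8 (a 70B-class GQA configuration),
n_rep = query_heads // kv_heads = 8, and the expanded key/value tensors match
the query's head count.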