
Commit a2d3007: Fix fusedSDPA fp8 70B issue

Signed-off-by: Chendi.Xue <[email protected]>
xuechendi committed Jan 10, 2025
1 parent 67a2923 · commit a2d3007
Showing 1 changed file with 1 addition and 3 deletions.
vllm/attention/backends/hpu_attn.py (4 changes: 1 addition & 3 deletions)
@@ -49,14 +49,12 @@ def prompt_fsdpa(
     query_heads = query.size(1)
     kv_heads = key.size(1)
     VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE = os.environ.get(
-        'VLLM_REMOVE_REPEAT_KV_CACHE_MERGED_PREFILL', '1') == '1'
+        'VLLM_REMOVE_REPEAT_KV_CACHE', '1') == '1'
     # TODO: remove after fusedsdpa fix for query_heads != kv_heads
     if query_heads != kv_heads:
         if VLLM_DO_NOT_REMOVE_REPEAT_KV_CACHE:
             key = ops.repeat_kv(key, int(query_heads // kv_heads))
             value = ops.repeat_kv(value, int(query_heads // kv_heads))
-        if attn_bias is not None:
-            attn_bias = attn_bias.unsqueeze(1)
     softmax_mode = 'fast'
     recompute_mode = True
     attn_weights = fsdpa_op(query, key, value, attn_bias, 0.0, False,
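
For context: ops.repeat_kv expands the key/value tensors so their head count matches the query's, which the fused SDPA kernel needs whenever query_heads != kv_heads (grouped-query attention, as in 70B-class models with more query heads than KV heads). Below is a minimal sketch of what such a helper typically does; it is modeled on the common repeat_kv pattern, and its tensor layout (batch, heads, seq_len, head_dim) is an assumption, not the actual implementation of ops.repeat_kv in this tree.

import torch

def repeat_kv_sketch(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Sketch (assumption): duplicate each KV head n_rep times along the
    # head axis, layout (batch, kv_heads, seq_len, head_dim), so that
    # kv_heads * n_rep == query_heads afterwards.
    if n_rep == 1:
        return kv
    batch, kv_heads, seq_len, head_dim = kv.shape
    # Insert a repeat axis, broadcast across it, then fold it into heads.
    kv = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return kv.reshape(batch, kv_heads * n_rep, seq_len, head_dim)

# Hypothetical GQA shapes: 8 KV heads repeated to match 64 query heads.
key = torch.randn(1, 8, 128, 128)
key = repeat_kv_sketch(key, 64 // 8)  # -> shape (1, 64, 128, 128)

After this change the repeat path is gated by VLLM_REMOVE_REPEAT_KV_CACHE (default '1') rather than VLLM_REMOVE_REPEAT_KV_CACHE_MERGED_PREFILL, and attn_bias is passed to fsdpa_op without the extra unsqueezed head dimension that the deleted lines added.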
