
Commit

[SW-197036] - use torch._scaled_mm with hpu (#660)
Remove the workaround that used torch.ops.hpu.fp8_gemm_v2 on HPU.
nirda7 authored Jan 9, 2025
1 parent fa9dbf2 commit 73aaf71
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -149,19 +149,12 @@ def apply_fp8_linear(

     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        if current_platform.is_hpu():
-            #hpu does not support torch._scaled_mm (SW-197036)
-            output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight,
-                                               False, None, input.dtype,
-                                               x_scale, weight_scale, None,
-                                               False)
-        else:
-            output = torch._scaled_mm(qinput,
-                                      weight,
-                                      out_dtype=input.dtype,
-                                      scale_a=x_scale,
-                                      scale_b=weight_scale,
-                                      bias=bias)
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  out_dtype=input.dtype,
+                                  scale_a=x_scale,
+                                  scale_b=weight_scale,
+                                  bias=bias)

         # A fix for discrepancy in scaled_mm which returns tuple
         # for torch < 2.5 and a single value in torch >= 2.5
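For reference, a minimal standalone sketch (not part of this commit) of the per-tensor FP8 GEMM call that the unified code path now uses on both CUDA and HPU. Tensor names mirror the diff (qinput, weight, x_scale, weight_scale); the shapes, dtypes, and device selection are assumptions, and the call requires a device with FP8 support (e.g. Gaudi with the HPU bridge, or CUDA with compute capability >= 8.9).

import torch

# Assumed setup: small shapes and bfloat16 output dtype for illustration.
device = "cuda"  # use "hpu" on Gaudi; torch._scaled_mm is not available on CPU
M, K, N = 16, 32, 64
out_dtype = torch.bfloat16

# Per-tensor scales: one float32 scalar for the activation, one for the weight.
x_scale = torch.tensor(0.5, dtype=torch.float32, device=device)
weight_scale = torch.tensor(0.25, dtype=torch.float32, device=device)

# FP8-quantized activation (row-major, M x K) and weight presented as a
# column-major K x N operand, the layout torch._scaled_mm expects.
qinput = torch.randn(M, K, dtype=out_dtype, device=device).to(torch.float8_e4m3fn)
weight = torch.randn(N, K, dtype=out_dtype, device=device).to(torch.float8_e4m3fn).t()

# Fused GEMM + dequant: the scales are applied inside the kernel and the
# result comes back in the requested higher-precision dtype.
output = torch._scaled_mm(qinput,
                          weight,
                          out_dtype=out_dtype,
                          scale_a=x_scale,
                          scale_b=weight_scale,
                          bias=None)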
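The trailing context comment refers to version handling below the hunk: before torch 2.5, torch._scaled_mm returned a tuple whose first element is the output, while torch >= 2.5 returns a single tensor. A hedged sketch of that kind of normalization (not the file's actual code, and the helper name is hypothetical):

def _extract_scaled_mm_output(result):
    # torch < 2.5 returns a tuple with the output first; torch >= 2.5
    # returns the output tensor directly.
    if isinstance(result, tuple):
        return result[0]
    return result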
