From ac9296dc60542ca19777c82c382613ef41356df5 Mon Sep 17 00:00:00 2001
From: Nir David
Date: Sun, 22 Dec 2024 18:45:38 +0200
Subject: [PATCH] [SW-197036] - use torch._scaled_mm with hpu

---
 .../layers/quantization/utils/w8a8_utils.py | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 8f214861e3cee..c383fec781f7a 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -149,19 +149,12 @@ def apply_fp8_linear(
 
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        if current_platform.is_hpu():
-            #hpu does not support torch._scaled_mm (SW-197036)
-            output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight,
-                                               False, None, input.dtype,
-                                               x_scale, weight_scale, None,
-                                               False)
-        else:
-            output = torch._scaled_mm(qinput,
-                                      weight,
-                                      out_dtype=input.dtype,
-                                      scale_a=x_scale,
-                                      scale_b=weight_scale,
-                                      bias=bias)
+        output = torch._scaled_mm(qinput,
+                                  weight,
+                                  out_dtype=input.dtype,
+                                  scale_a=x_scale,
+                                  scale_b=weight_scale,
+                                  bias=bias)
 
         # A fix for discrepancy in scaled_mm which returns tuple
         # for torch < 2.5 and a single value in torch >= 2.5
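
Below is a minimal, standalone sketch of the per-tensor FP8 GEMM path this patch converges on, assuming a recent PyTorch build (>= 2.5) with an FP8-capable backend; the tensor names, shapes, and device string are illustrative assumptions and are not taken from vLLM's apply_fp8_linear.

```python
import torch

# Assumption: "hpu" requires the Intel Gaudi PyTorch bridge; "cuda" works on an
# FP8-capable GPU. Pick whichever backend is available.
device = "cuda"
dtype = torch.bfloat16
M, K, N = 16, 32, 64  # torch._scaled_mm requires dimensions divisible by 16

# Illustrative activations and (row-major) weights in the model dtype.
x = torch.randn(M, K, dtype=dtype, device=device)
w = torch.randn(N, K, dtype=dtype, device=device)

# Per-tensor scales as float32 tensors, as _scaled_mm expects.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_scale = x.abs().max().float() / fp8_max
w_scale = w.abs().max().float() / fp8_max

# Quantize both operands to FP8.
qinput = (x / x_scale).to(torch.float8_e4m3fn)
qweight = (w / w_scale).to(torch.float8_e4m3fn)

# Fused GEMM + dequantization. The second operand must be column-major,
# hence the transpose of the row-major quantized weight.
output = torch._scaled_mm(qinput,
                          qweight.t(),
                          scale_a=x_scale,
                          scale_b=w_scale,
                          out_dtype=dtype)

# On torch < 2.5, _scaled_mm returned an (output, amax) tuple rather than a
# single tensor; this mirrors the discrepancy noted in the patched helper.
if isinstance(output, tuple):
    output = output[0]

print(output.shape, output.dtype)  # torch.Size([16, 64]) torch.bfloat16
```

The sketch only covers the fully per-tensor case removed-branch-free by this patch; per-channel or per-token scaling takes a different path in w8a8_utils.py and is not shown here.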