From 23e931b188a04fe0036864ca4783b924e6953bab Mon Sep 17 00:00:00 2001 From: yan tomsinsky Date: Mon, 16 Sep 2024 15:27:58 +0300 Subject: [PATCH] Support loading checkpoints quantized using Autofp8 --- vllm/hpu/ops.py | 55 ++++++++++++++++++- .../layers/fused_moe/fused_moe.py | 5 ++ .../compressed_tensors/compressed_tensors.py | 5 +- .../schemes/compressed_tensors_w8a8_fp8.py | 2 +- .../model_executor/layers/quantization/fp8.py | 21 ++++--- .../layers/quantization/utils/w8a8_utils.py | 37 ++++++++++--- vllm/model_executor/models/llama.py | 6 +- vllm/worker/habana_model_runner.py | 3 +- 8 files changed, 110 insertions(+), 24 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index 939d195a12b08..323a33e9fa2a7 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. ############################################################################### -from typing import Optional +from typing import Optional, Tuple import habana_frameworks.torch as htorch import torch @@ -291,3 +291,56 @@ def forward(self, hidden_states, w1, w2, score, topk): final_hidden_states += current_hidden_states_static return final_hidden_states.view(-1, D) + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + batch_dim_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensor for downstream kernels that + will benefit from padding. + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + batch_dim_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + if batch_dim_padding: + shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:]) + output = torch.empty(shape, + device=input.device, + dtype=torch.float8_e4m3fn) + else: + output = torch.empty_like(input, dtype=torch.float8_e4m3fn) + if scale is None: + raise "dynamic scaled_fp8_quant not implemented for HPU" + #TODO: calculate scale to match gaudi2 240 range instead of 448 + if use_per_token_if_dynamic: + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + torch.ops._C.dynamic_per_token_scaled_fp8_quant( + output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) + else: + output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False, dtype=torch.float8_e4m3fn)[0] + + return output, scale \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 413c0b6d0924e..3682362c5a864 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -11,6 +11,11 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.platforms import current_platform + +if current_platform.is_hpu(): + from vllm.hpu.ops import scaled_fp8_quant + ops.scaled_fp8_quant = scaled_fp8_quant logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 39d00bd5733ff..badb29af1f5f6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -233,7 +233,7 @@ def _get_scheme_from_parts( if is_activation_quantization_format(self.quant_format): if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( - CompressedTensorsW8A8Fp8.get_min_capability(), error=False) + CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if torch.cuda.is_available() else True if is_fp8_w8a8_supported: return CompressedTensorsW8A8Fp8( strategy=weight_quant.strategy, @@ -306,7 +306,8 @@ def get_scheme( # Raise error if device does not support the scheme # (e.g. fp8 needs ada lovelace) - self._check_scheme_supported(scheme.get_min_capability()) + if torch.cuda.is_available(): + self._check_scheme_supported(scheme.get_min_capability()) return scheme diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index cc9d71db140c2..631774994b5c0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -21,7 +21,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme - self.cutlass_fp8_supported = cutlass_fp8_supported() + self.cutlass_fp8_supported = cutlass_fp8_supported() if torch.cuda.is_available() else False @classmethod def get_min_capability(cls) -> int: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index c829cb836ee4c..f3e304ce141c6 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,6 +23,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import print_warning_once +if current_platform.is_hpu(): + from vllm.hpu.ops import scaled_fp8_quant + ops.scaled_fp8_quant = scaled_fp8_quant ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -112,13 +115,17 @@ class Fp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config - self.cutlass_fp8_supported = cutlass_fp8_supported() - - # For GPUs that lack FP8 hardware support, we can leverage the Marlin - # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 + if torch.cuda.is_available(): + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + self.use_marlin = capability < 89 + else: + self.cutlass_fp8_supported = False + self.use_marlin = False def create_weights( self, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 20100c76bd690..8904c9fa1789e 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -6,7 +6,10 @@ from vllm import _custom_ops as ops from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform - +if current_platform.is_hpu(): + import habana_frameworks.torch.utils.experimental as htexp + from vllm.hpu.ops import scaled_fp8_quant + ops.scaled_fp8_quant = scaled_fp8_quant def cutlass_fp8_supported() -> bool: capability = current_platform.get_device_capability() @@ -18,7 +21,15 @@ def cutlass_fp8_supported() -> bool: def per_tensor_dequantize( tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]) -> torch.Tensor: - fake_qweight = tensor.to(torch.float16) + dtype = torch.float16 + device = tensor.device + if current_platform.is_hpu(): + dtype = torch.bfloat16 + if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + #dequant on cpu to avoid nan on gaudi2 + tensor = tensor.to('cpu') + + fake_qweight = tensor.to(dtype).to(device) dq_weight = fake_qweight * inv_scale return dq_weight @@ -76,7 +87,8 @@ def requantize_with_max_scale( logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: # Max scale to be used for requanitzation. max_w_scale = weight_scale.max() - + if current_platform.is_hpu() and htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2: + max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max/torch.finfo(torch.float8_e4m3fnuz).max) # QKV / MLP is fused in the on disk checkpoint if any of the # weight scales are still set to the default since we initialize # N weight scales for N shards but we only load 1 weight scale @@ -147,12 +159,19 @@ def apply_fp8_linear( if per_tensor_weights and per_tensor_activations: # Fused GEMM_DQ - output, _ = torch._scaled_mm(qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias) + if current_platform.is_hpu(): + #hpu does not support torch._scaled_mm (SW-197036) + output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight, + False, None, input.dtype, + x_scale, weight_scale, None, + False) + else: + output, _ = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) return torch.narrow(output, 0, 0, input.shape[0]) else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 51716b12513d8..8ccefe7be33f5 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -54,8 +54,9 @@ from .interfaces import SupportsLoRA from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers - -is_hpu = current_platform.is_hpu() +from vllm.platforms import current_platform +if current_platform.is_hpu(): + import habana_frameworks.torch.core as htcore class LlamaMLP(nn.Module): @@ -521,6 +522,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) if current_platform.is_hpu(): torch.hpu.synchronize() + htcore.mark_step() # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index ce3848ae0a6da..b0b9114ac2d0a 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -562,8 +562,7 @@ def _set_gc_threshold(self) -> None: def load_model(self) -> None: import habana_frameworks.torch.core as htcore - if self.model_config.quantization == 'inc': - htcore.hpu_set_env() + htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: self.model = get_model(