Support loading checkpoints quantized using AutoFP8
Yantom1 committed Sep 16, 2024
1 parent f4ac1f9 commit 23e931b
Showing 8 changed files with 110 additions and 24 deletions.
55 changes: 54 additions & 1 deletion vllm/hpu/ops.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
###############################################################################
from typing import Optional
from typing import Optional, Tuple

import habana_frameworks.torch as htorch
import torch
@@ -291,3 +291,56 @@ def forward(self, hidden_states, w1, w2, score, topk):
            final_hidden_states += current_hidden_states_static

        return final_hidden_states.view(-1, D)

# fp8
def scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    batch_dim_padding: Optional[int] = None,
    scale_ub: Optional[torch.Tensor] = None,
    use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP8 and return the quantized tensor and scale.

    This function supports both static and dynamic quantization: if a
    scale is provided, static scaling is used; if it is omitted, the
    scale is determined dynamically. The function also allows optional
    padding of the output tensor for downstream kernels that benefit
    from padding.

    Args:
        input: The input tensor to be quantized to FP8.
        scale: Optional scaling factor for the FP8 quantization.
        scale_ub: Optional upper bound for the scaling factor in the
            dynamic per-token case.
        batch_dim_padding: If specified, pad the first dimension of the
            output to at least this value.
        use_per_token_if_dynamic: Whether to do per-tensor or per-token
            quantization in the dynamic case.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
        the scaling factor.
    """
    if batch_dim_padding:
        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
        output = torch.empty(shape,
                             device=input.device,
                             dtype=torch.float8_e4m3fn)
    else:
        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
    if scale is None:
        raise NotImplementedError(
            "dynamic scaled_fp8_quant not implemented for HPU")
        # TODO: calculate scale to match the Gaudi2 240 range instead of 448
        if use_per_token_if_dynamic:
            scale = torch.empty((input.numel() // input.shape[-1], 1),
                                device=input.device,
                                dtype=torch.float32)
            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
                output, input, scale, scale_ub)
        else:
            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    else:
        output = torch.ops.hpu.cast_to_fp8_v2(input, 1 / scale, False, False,
                                              dtype=torch.float8_e4m3fn)[0]

    return output, scale
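
For reference, a minimal usage sketch of the static path above (values are illustrative; assumes a Gaudi device with habana_frameworks available):

import torch
from vllm.hpu.ops import scaled_fp8_quant

# Static quantization: a precomputed per-tensor scale is passed in and
# returned unchanged alongside the FP8 tensor.
x = torch.randn(4, 4096, dtype=torch.bfloat16, device="hpu")
scale = torch.tensor(0.5, dtype=torch.float32, device="hpu")
x_fp8, x_scale = scaled_fp8_quant(x, scale=scale)
assert x_fp8.dtype == torch.float8_e4m3fn
assert x_scale is scale  # the static path returns the given scale
# Omitting the scale raises NotImplementedError on HPU for now.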
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -11,6 +11,11 @@
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform

if current_platform.is_hpu():
    from vllm.hpu.ops import scaled_fp8_quant
    ops.scaled_fp8_quant = scaled_fp8_quant

logger = init_logger(__name__)

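The two lines above are the whole dispatch mechanism: rebinding the module attribute at import time means every existing call site that goes through vllm._custom_ops picks up the HPU implementation without modification. A self-contained sketch of the pattern (toy names, not vLLM code):

import types

# Stand-in for vllm._custom_ops with a default (CUDA) implementation.
ops = types.SimpleNamespace(scaled_fp8_quant=lambda x: ("cuda", x))

def hpu_scaled_fp8_quant(x):
    # Stand-in for vllm.hpu.ops.scaled_fp8_quant.
    return ("hpu", x)

is_hpu = True  # stand-in for current_platform.is_hpu()
if is_hpu:
    ops.scaled_fp8_quant = hpu_scaled_fp8_quant

# All later callers of ops.scaled_fp8_quant now hit the HPU path.
print(ops.scaled_fp8_quant(42))  # ('hpu', 42)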
@@ -233,7 +233,7 @@ def _get_scheme_from_parts(
        if is_activation_quantization_format(self.quant_format):
            if self._is_fp8_w8a8(weight_quant, input_quant):
                is_fp8_w8a8_supported = self._check_scheme_supported(
                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if torch.cuda.is_available() else True
                if is_fp8_w8a8_supported:
                    return CompressedTensorsW8A8Fp8(
                        strategy=weight_quant.strategy,
@@ -306,7 +306,8 @@ def get_scheme(

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())
        if torch.cuda.is_available():
            self._check_scheme_supported(scheme.get_min_capability())

        return scheme

@@ -21,7 +21,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.cutlass_fp8_supported = cutlass_fp8_supported() if torch.cuda.is_available() else False

    @classmethod
    def get_min_capability(cls) -> int:
21 changes: 14 additions & 7 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -23,6 +23,9 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
if current_platform.is_hpu():
    from vllm.hpu.ops import scaled_fp8_quant
    ops.scaled_fp8_quant = scaled_fp8_quant

ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -112,13 +115,17 @@ class Fp8LinearMethod(LinearMethodBase):

    def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89
        if torch.cuda.is_available():
            self.cutlass_fp8_supported = cutlass_fp8_supported()

            # For GPUs that lack FP8 hardware support, we can leverage the
            # Marlin kernel for fast weight-only FP8 quantization
            capability = current_platform.get_device_capability()
            capability = capability[0] * 10 + capability[1]
            self.use_marlin = capability < 89
        else:
            self.cutlass_fp8_supported = False
            self.use_marlin = False

    def create_weights(
        self,
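As context for the guarded block above (illustrative, not part of the diff): get_device_capability() returns a (major, minor) tuple, so the encoding major * 10 + minor maps Ada Lovelace (8, 9) to 89, the first CUDA architecture with native FP8 support; anything below falls back to the Marlin weight-only kernel.

# Illustrative check of the capability encoding used above.
for (major, minor), name in [((7, 5), "Turing"), ((8, 0), "Ampere"),
                             ((8, 9), "Ada Lovelace"), ((9, 0), "Hopper")]:
    capability = major * 10 + minor
    use_marlin = capability < 89
    print(f"{name}: capability={capability}, use_marlin={use_marlin}")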
37 changes: 28 additions & 9 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -6,7 +6,10 @@
from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

if current_platform.is_hpu():
    import habana_frameworks.torch.utils.experimental as htexp
    from vllm.hpu.ops import scaled_fp8_quant
    ops.scaled_fp8_quant = scaled_fp8_quant

def cutlass_fp8_supported() -> bool:
    capability = current_platform.get_device_capability()
@@ -18,7 +21,15 @@ def cutlass_fp8_supported() -> bool:
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dtype = torch.float16
    device = tensor.device
    if current_platform.is_hpu():
        dtype = torch.bfloat16
        if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
            # dequant on CPU to avoid NaN on Gaudi2
            tensor = tensor.to('cpu')

    fake_qweight = tensor.to(dtype).to(device)
    dq_weight = fake_qweight * inv_scale
    return dq_weight
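
A tiny numeric sketch of what per_tensor_dequantize computes (runs on CPU; values chosen to be exactly representable in E4M3):

import torch

w_fp8 = torch.tensor([0.5, -1.0, 2.0]).to(torch.float8_e4m3fn)
inv_scale = 4.0
fake_qweight = w_fp8.to(torch.bfloat16)  # upcast before multiplying
dq_weight = fake_qweight * inv_scale
print(dq_weight)  # tensor([ 2., -4.,  8.], dtype=torch.bfloat16)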

@@ -76,7 +87,8 @@ def requantize_with_max_scale(
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()
    if current_platform.is_hpu() and htexp._get_device_type(
    ) == htexp.synDeviceType.synDeviceGaudi2:
        max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max /
                                     torch.finfo(torch.float8_e4m3fnuz).max)
    # QKV / MLP is fused in the on-disk checkpoint if any of the
    # weight scales are still set to the default, since we initialize
    # N weight scales for N shards but we only load 1 weight scale
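
The rescale above exists because Gaudi2's FP8 arithmetic saturates at the E4M3FNUZ maximum of 240, while checkpoints quantized with AutoFP8 assume the E4M3FN maximum of 448; multiplying the scale by 448/240 ≈ 1.867 shrinks the quantized values back into the hardware's range. A quick check:

import torch

fn_max = torch.finfo(torch.float8_e4m3fn).max      # 448.0
fnuz_max = torch.finfo(torch.float8_e4m3fnuz).max  # 240.0
ratio = fn_max / fnuz_max                          # ~1.8667

# A weight that quantized to 448 under the checkpoint scale maps to
# 240 under the enlarged scale, staying inside Gaudi2's FP8 range.
print(448.0 / ratio)  # 240.0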
@@ -147,12 +159,19 @@ def apply_fp8_linear(

    if per_tensor_weights and per_tensor_activations:
        # Fused GEMM_DQ
        output, _ = torch._scaled_mm(qinput,
                                     weight,
                                     out_dtype=input.dtype,
                                     scale_a=x_scale,
                                     scale_b=weight_scale,
                                     bias=bias)
        if current_platform.is_hpu():
            # HPU does not support torch._scaled_mm (SW-197036)
            output = torch.ops.hpu.fp8_gemm_v2(qinput, False, weight,
                                               False, None, input.dtype,
                                               x_scale, weight_scale, None,
                                               False)
        else:
            output, _ = torch._scaled_mm(qinput,
                                         weight,
                                         out_dtype=input.dtype,
                                         scale_a=x_scale,
                                         scale_b=weight_scale,
                                         bias=bias)
        return torch.narrow(output, 0, 0, input.shape[0])

    else:
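For intuition, both branches above compute the same dequantize-then-matmul; a reference sketch of the semantics (not the fused kernel, shapes illustrative):

import torch

qinput = torch.randn(4, 8).to(torch.float8_e4m3fn)   # FP8 activations
weight = torch.randn(8, 16).to(torch.float8_e4m3fn)  # FP8 weights
x_scale = torch.tensor(0.1)
weight_scale = torch.tensor(0.2)

# Dequantize both operands, then matmul in a higher-precision dtype.
output = (qinput.to(torch.float32) * x_scale) @ (
    weight.to(torch.float32) * weight_scale)
print(output.shape)  # torch.Size([4, 16])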
6 changes: 4 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -54,8 +54,9 @@

from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers

is_hpu = current_platform.is_hpu()
from vllm.platforms import current_platform
if current_platform.is_hpu():
    import habana_frameworks.torch.core as htcore


class LlamaMLP(nn.Module):
@@ -521,6 +522,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                weight_loader(param, loaded_weight)
                if current_platform.is_hpu():
                    torch.hpu.synchronize()
                    htcore.mark_step()

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
3 changes: 1 addition & 2 deletions vllm/worker/habana_model_runner.py
@@ -562,8 +562,7 @@ def _set_gc_threshold(self) -> None:

    def load_model(self) -> None:
        import habana_frameworks.torch.core as htcore
        if self.model_config.quantization == 'inc':
            htcore.hpu_set_env()
        htcore.hpu_set_env()
        with HabanaMemoryProfiler() as m:
            with HabanaMemoryProfiler() as m_getmodel:
                self.model = get_model(
