ruff fixes
Yantom1 committed Sep 16, 2024
1 parent 23e931b commit 363de3c
Showing 6 changed files with 15 additions and 11 deletions.
5 changes: 3 additions & 2 deletions vllm/hpu/ops.py
@@ -329,7 +329,7 @@ def scaled_fp8_quant(
     else:
         output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
-        raise "dynamic scaled_fp8_quant not implemented for HPU"
+        raise RuntimeError("dynamic scaled_fp8_quant not implemented for HPU")
         #TODO: calculate scale to match gaudi2 240 range instead of 448
         if use_per_token_if_dynamic:
             scale = torch.empty((input.numel() // input.shape[-1], 1),
@@ -341,6 +341,7 @@ def scaled_fp8_quant(
             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
-        output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False, dtype=torch.float8_e4m3fn)[0]
+        output = torch.ops.hpu.cast_to_fp8_v2(input, 1/scale, False, False,
+                                              dtype=torch.float8_e4m3fn)[0]
 
     return output, scale
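
Background on the first hunk: in Python 3, raise only accepts exception instances or classes, so the old raise of a bare string would itself fail with a TypeError instead of surfacing the intended message; switching to RuntimeError is the straightforward fix. A minimal sketch (toy function, not the vLLM code) of the difference:

# Minimal sketch, not vLLM code: raising a bare string fails in Python 3,
# which is why the hunk switches to RuntimeError.
def quantize_stub(scale=None):
    if scale is None:
        # raise "dynamic scale not implemented"  # TypeError: exceptions must
        #                                        # derive from BaseException
        raise RuntimeError("dynamic scale not implemented")
    return scale

try:
    quantize_stub()
except RuntimeError as err:
    print(err)  # -> dynamic scale not implemented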
@@ -233,7 +233,8 @@ def _get_scheme_from_parts(
         if is_activation_quantization_format(self.quant_format):
             if self._is_fp8_w8a8(weight_quant, input_quant):
                 is_fp8_w8a8_supported = self._check_scheme_supported(
-                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False) if torch.cuda.is_available() else True
+                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False) \
+                    if torch.cuda.is_available() else True
                 if is_fp8_w8a8_supported:
                     return CompressedTensorsW8A8Fp8(
                         strategy=weight_quant.strategy,
@@ -21,7 +21,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.cutlass_fp8_supported = cutlass_fp8_supported() if torch.cuda.is_available() else False
+        self.cutlass_fp8_supported = cutlass_fp8_supported() \
+            if torch.cuda.is_available() else False
 
     @classmethod
     def get_min_capability(cls) -> int:
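
A note on the two wraps above: both long ternaries are split with backslash continuations to satisfy the line-length rule, and the CUDA-only probes are skipped entirely when no CUDA device is present (e.g. on HPU). An equivalent form — a sketch with assumed names, not the repository code — wraps the expression in parentheses instead of using a backslash:

# Sketch only, assumed names: parenthesized wrap of the same guarded probe.
import torch

def probe_cutlass_fp8(cutlass_fp8_supported):
    # Only call the CUDA-specific probe when a CUDA device is actually present.
    return (cutlass_fp8_supported()
            if torch.cuda.is_available() else False)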
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -118,8 +118,8 @@ def __init__(self, quant_config: Fp8Config):
         if torch.cuda.is_available():
             self.cutlass_fp8_supported = cutlass_fp8_supported()
 
-            # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-            # kernel for fast weight-only FP8 quantization
+            # For GPUs that lack FP8 hardware support, we can leverage the
+            # Marlin kernel for fast weight-only FP8 quantization
             capability = current_platform.get_device_capability()
             capability = capability[0] * 10 + capability[1]
             self.use_marlin = capability < 89
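
For reference, get_device_capability() returns a (major, minor) pair, so major * 10 + minor turns compute capability 8.9 into 89; anything below that lacks native FP8 tensor-core support and falls back to the Marlin weight-only kernel. A small worked example (device names are illustrative):

# Worked example of the capability arithmetic above; names are illustrative.
for name, (major, minor) in {"A100": (8, 0), "L40S": (8, 9), "H100": (9, 0)}.items():
    capability = major * 10 + minor
    use_marlin = capability < 89
    print(name, capability, "Marlin weight-only" if use_marlin else "native FP8")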
6 changes: 4 additions & 2 deletions vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -87,8 +87,10 @@ def requantize_with_max_scale(
                               logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
     # Max scale to be used for requanitzation.
     max_w_scale = weight_scale.max()
-    if current_platform.is_hpu() and htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
-        max_w_scale = max_w_scale * (torch.finfo(torch.float8_e4m3fn).max/torch.finfo(torch.float8_e4m3fnuz).max)
+    if current_platform.is_hpu() and \
+            htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
+        max_w_scale = max_w_scale * \
+            (torch.finfo(torch.float8_e4m3fn).max/torch.finfo(torch.float8_e4m3fnuz).max)
     # QKV / MLP is fused in the on disk checkpoint if any of the
     # weight scales are still set to the default since we initialize
     # N weight scales for N shards but we only load 1 weight scale
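
The Gaudi2 branch compensates for the device's narrower FP8 range: torch.float8_e4m3fn has a maximum of 448 while the e4m3fnuz variant tops out at 240, so the checkpoint's weight scale is stretched by 448 / 240 ≈ 1.87. A runnable sketch of just that factor (requires a PyTorch build with the float8 dtypes):

# Sketch of the rescaling factor applied on Gaudi2 above.
import torch

fn_max = torch.finfo(torch.float8_e4m3fn).max      # 448.0
fnuz_max = torch.finfo(torch.float8_e4m3fnuz).max  # 240.0
print(fn_max / fnuz_max)                           # ~1.8667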
5 changes: 2 additions & 3 deletions vllm/model_executor/models/llama.py
@@ -48,7 +48,6 @@
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors, SamplerOutput
 from vllm.utils import is_hip
@@ -321,7 +320,7 @@ def forward(
         hidden_states = intermediate_tensors["hidden_states"]
         residual = intermediate_tensors["residual"]
 
-        if is_hpu:
+        if current_platform.is_hpu():
             import habana_frameworks.torch as htorch
             htorch.core.mark_step()
         for i in range(self.start_layer, self.end_layer):
@@ -333,7 +332,7 @@
                 attn_metadata,
                 residual,
             )
-        if is_hpu:
+        if current_platform.is_hpu():
             htorch.core.mark_step()
 
         if not get_pp_group().is_last_rank:
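
The llama.py hunks only swap the bare is_hpu flag for the current_platform.is_hpu() helper; the surrounding pattern of calling htorch.core.mark_step() before and after the decoder-layer loop (to delimit lazy-mode graphs on HPU) is unchanged. A sketch of that pattern, assuming a vLLM installation with HPU support and a stand-in loop body:

# Sketch, assuming vLLM with HPU support; the loop body is a stand-in.
from vllm.platforms import current_platform

def run_layers(layers, hidden_states):
    if current_platform.is_hpu():
        import habana_frameworks.torch as htorch
        htorch.core.mark_step()  # flush the lazy-mode graph before the loop
    for layer in layers:
        hidden_states = layer(hidden_states)
    if current_platform.is_hpu():
        htorch.core.mark_step()  # and once more after the loop
    return hidden_states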
