From 33b2b1fa81d00fee61223a2c0f007e90579281d5 Mon Sep 17 00:00:00 2001
From: Nir David
Date: Sun, 11 Aug 2024 12:00:09 +0300
Subject: [PATCH] Inc on vLLM - Fix CR comments

---
 vllm/config.py                      | 4 ++--
 vllm/engine/arg_utils.py            | 4 ++--
 vllm/model_executor/models/llama.py | 7 ++++---
 vllm/utils.py                       | 2 +-
 vllm/worker/cache_engine.py         | 2 +-
 vllm/worker/habana_model_runner.py  | 4 +++-
 vllm/worker/habana_worker.py        | 3 ---
 7 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index ec7e8fed30fdb..80167d1c872ef 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -474,13 +474,13 @@ def _verify_args(self) -> None:
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
-        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "hf8"):
+        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"):
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
                 "Meanwhile, it may cause accuracy drop without a proper "
                 "scaling factor. "
-                "FP8_E4M3 is also supported on hpu (hf8).")
+                "Intel Gaudi (HPU) also supports fp8 (using fp8_inc).")
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 29160143ef469..1b07aa3540f7b 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -229,12 +229,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument(
             '--kv-cache-dtype',
             type=str,
-            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'hf8'],
+            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'],
             default=EngineArgs.kv_cache_dtype,
             help='Data type for kv cache storage. If "auto", will use model '
             'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
             'ROCm (AMD GPU) supports fp8 (=fp8_e4m3). '
-            'FP8_E4M3 is also supported on hpu (hf8).')
+            'Intel Gaudi (HPU) also supports fp8 (using fp8_inc).')
         parser.add_argument(
             '--quantization-param-path',
             type=nullable_str,
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index b71a4ee7e3b9d..d48cc03035d11 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -49,7 +49,8 @@
     default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput
-from vllm.utils import is_hip, is_hpu
+from vllm.utils import is_hip
+from vllm.platforms import current_platform
 
 from .interfaces import SupportsLoRA
 from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
@@ -317,7 +318,7 @@ def forward(
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        if is_hpu():
+        if current_platform.is_hpu():
             import habana_frameworks.torch as htorch
             htorch.core.mark_step()
         for i in range(self.start_layer, self.end_layer):
@@ -329,7 +330,7 @@ def forward(
                 attn_metadata,
                 residual,
             )
-            if is_hpu():
+            if current_platform.is_hpu():
                 htorch.core.mark_step()
 
         if not get_pp_group().is_last_rank:
diff --git a/vllm/utils.py b/vllm/utils.py
index 63821960739f1..323558e235f04 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -39,7 +39,7 @@
     "fp8": torch.uint8,
     "fp8_e4m3": torch.uint8,
     "fp8_e5m2": torch.uint8,
-    "hf8": torch.float8_e4m3fn,
+    "fp8_inc": torch.float8_e4m3fn,
 }
 
 TORCH_DTYPE_TO_NUMPY_DTYPE = {
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index 8e41cbfd511ff..2707e7eb92f7b 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -91,7 +91,7 @@ def _allocate_kv_cache(
             # null block in CpuGpuBlockAllocator requires at least that
             # block to be zeroed-out.
             # We zero-out everything for simplicity.
-            dtype = torch.int8 if self.dtype == torch.float8_e4m3fn else self.dtype
+            dtype = torch.uint8 if self.dtype == torch.float8_e4m3fn else self.dtype
             kv_cache.append(
                 torch.zeros(kv_cache_shape,
                             dtype=dtype,
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 87b0b83bb74ea..7fc4fe6d5a847 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -413,6 +413,9 @@ def __init__(
         self._setup_buckets()
 
     def load_model(self) -> None:
+        if self.model_config.quantization == 'inc':
+            import habana_frameworks.torch.core as htcore
+            htcore.hpu_set_env()
         with HabanaMemoryProfiler() as m:
             with HabanaMemoryProfiler() as m_getmodel:
                 self.model = get_model(
@@ -429,7 +432,6 @@ def load_model(self) -> None:
                    f"took {m_getmodel.get_summary_string()}")
             logger.info(msg)
 
-            import habana_frameworks.torch.core as htcore
             if self.model_config.quantization == 'inc':
                 logger.info("Preparing model with INC..")
                 with HabanaMemoryProfiler() as m_inc:
diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py
index ffebfbfeda1bc..ece7b59e15e10 100644
--- a/vllm/worker/habana_worker.py
+++ b/vllm/worker/habana_worker.py
@@ -114,9 +114,6 @@ def init_device(self) -> None:
         set_random_seed(self.model_config.seed)
 
     def load_model(self):
-        if self.model_config.quantization == 'inc':
-            import habana_frameworks.torch.core as htcore
-            htcore.hpu_set_env()
         self.model_runner.load_model()
 
     @torch.inference_mode()
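For reference, below is a minimal offline-inference sketch (not part of the patch) showing how the options touched here fit together: the renamed fp8_inc kv-cache dtype and the 'inc' quantization path that now calls hpu_set_env() inside the model runner. It assumes an Intel Gaudi (HPU) host with an HPU-enabled vLLM build and the INC (Intel Neural Compressor) stack installed; the model name is illustrative only.

# Sketch only: exercises the fp8_inc kv-cache dtype and the INC quantization
# path enabled by this change. Assumes an HPU-enabled vLLM build on Gaudi;
# the checkpoint name is illustrative, not prescribed by the patch.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # illustrative checkpoint
    quantization="inc",                # triggers the INC preparation in load_model()
    kv_cache_dtype="fp8_inc",          # new HPU kv-cache dtype name (was "hf8")
)

outputs = llm.generate(["Hello, Gaudi!"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)

The same pair of settings can be passed on the command line via --quantization inc and the --kv-cache-dtype fp8_inc choice added in arg_utils.py above.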