From 39c6b6c3f0551b69f67bda2b9c44c359f04b3c54 Mon Sep 17 00:00:00 2001
From: Marceli Fylcek
Date: Mon, 25 Nov 2024 09:36:28 +0100
Subject: [PATCH 1/4] Limit decode block size (#532)

Limit decode bucket size to num_hpu_blocks
---
 requirements-hpu.txt      | 4 +---
 vllm/worker/hpu_worker.py | 6 +++++-
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/requirements-hpu.txt b/requirements-hpu.txt
index 07f9c31117e49..ddf1caccf41d8 100644
--- a/requirements-hpu.txt
+++ b/requirements-hpu.txt
@@ -8,7 +8,5 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@61334c5
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@ac9740d
 neural-compressor @ git+https://github.com/intel/neural-compressor.git@b196432
-
-
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index 2b8f955265792..1004af0eca40a 100644
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -166,7 +166,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         if is_fake_hpu():
             cache_block_size = self.get_cache_block_size_bytes()
             fake_hpu_cache_alloc = 4 * 2**30  # take 4 GiB flat on fake hpu
-            return fake_hpu_cache_alloc // cache_block_size, 0
+            num_fake_hpu_blocks = fake_hpu_cache_alloc // cache_block_size
+            self.model_runner.bucketing_ctx.num_hpu_blocks = num_fake_hpu_blocks
+            return num_fake_hpu_blocks, 0
         with HabanaMemoryProfiler() as m:
             self.model_runner.profile_run()
             torch.hpu.synchronize()
@@ -203,6 +205,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         num_hpu_blocks = max(num_hpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)

+        self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks
+
         if self.model_runner.lora_manager:
             self.model_runner.remove_all_loras()

From 5eb8b1f75b8193d486e3aefbf7ea7a49594827ed Mon Sep 17 00:00:00 2001
From: Nir David <124874956+nirda7@users.noreply.github.com>
Date: Mon, 25 Nov 2024 10:41:16 +0200
Subject: [PATCH 2/4] fix marlin flag set on hpu (#540)

---
 vllm/model_executor/layers/quantization/fp8.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index bc2f97b4858f3..67b4d24452040 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -120,10 +120,12 @@ def __init__(self, quant_config: Fp8Config):
         if current_platform.is_cuda_alike():
             self.cutlass_fp8_supported = cutlass_fp8_supported()

-        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
-        # kernel for fast weight-only FP8 quantization
-        self.use_marlin = (not current_platform.has_device_capability(89)
-                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+        self.use_marlin = False
+        if not current_platform.is_hpu():
+            # For GPUs that lack FP8 hardware support, we can leverage the
+            # Marlin kernel for fast weight-only FP8 quantization
+            self.use_marlin = (not current_platform.has_device_capability(89)
+                               or envs.VLLM_TEST_FORCE_FP8_MARLIN)
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False

From 0f513bd19367c5cd09afba3f867ec5eb1ceeeae4 Mon Sep 17 00:00:00 2001
From: Karol Damaszke
Date: Tue, 26 Nov 2024 10:03:36 +0100
Subject: [PATCH 3/4] Fix profile run for multi LoRA (#549)

Fixes an issue with multi LoRA during `profile_run`.
---
 vllm/worker/hpu_model_runner.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 7aa68d1e98abf..de6d70dff8c76 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -1266,9 +1266,9 @@ def create_dummy_seq_group_metadata(self,
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        max_batch_size, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
-        max_seq_len = min(max_seq_len,
-                          self.max_num_batched_tokens // max_batch_size)
+        _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
+        max_batch_size = min(self.max_num_seqs,
+                             self.max_num_batched_tokens // max_seq_len)
         self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                              False, True)

From 7133502d4f7a46088ae6d0df0f07703d7edb3bbd Mon Sep 17 00:00:00 2001
From: Nir David
Date: Tue, 26 Nov 2024 11:08:50 +0200
Subject: [PATCH 4/4] fix cutlass_fp8_supported flag set on hpu

---
 vllm/model_executor/layers/quantization/fp8.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 67b4d24452040..0c6917b0d069c 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -117,6 +117,7 @@ class Fp8LinearMethod(LinearMethodBase):

     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
+        self.cutlass_fp8_supported = False
         if current_platform.is_cuda_alike():
             self.cutlass_fp8_supported = cutlass_fp8_supported()
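
For readers tracing the profile_run change in PATCH 3/4, here is a minimal, self-contained sketch of the new shape computation. The numeric values for max_num_seqs and max_num_batched_tokens and the stand-in get_max_prompt_shape helper are illustrative assumptions, not taken from the patch; only the min/floor-division formula mirrors the diff above.

# Illustrative sketch only: the values and the stand-in helper below are assumptions;
# the formula matches the new profile_run logic introduced in PATCH 3/4.

max_num_seqs = 256             # assumed scheduler limit on concurrent sequences
max_num_batched_tokens = 8192  # assumed per-batch token budget

def get_max_prompt_shape():
    # Stand-in for self.bucketing_ctx.get_max_prompt_shape(), which returns the
    # largest (batch_size, seq_len) prompt bucket configured for warmup.
    return 64, 1024

# New logic: keep the longest prompt bucket and derive the batch size from the
# token budget, instead of clamping the sequence length as the old code did.
_, max_seq_len = get_max_prompt_shape()
max_batch_size = min(max_num_seqs, max_num_batched_tokens // max_seq_len)

print(max_batch_size, max_seq_len)  # 8 1024 with the assumed values

With these assumed values, the old formula would have kept a batch size of 64 and shrunk max_seq_len to 8192 // 64 = 128, so the profile run would never exercise the full 1024-token prompt bucket; the new formula preserves max_seq_len and bounds the batch size instead.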