From aeac04a8c937a6bb2acff386ec389080b5ae05d1 Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Wed, 20 Nov 2024 00:48:35 +0000
Subject: [PATCH 1/6] [HPU]add mark_step configurable for decoder layer

---
 vllm/model_executor/models/llama.py |  8 +-------
 vllm/worker/hpu_model_runner.py     | 25 +++++++++++++------------
 2 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 1603bc18f132f..11dbbde24d67e 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -315,11 +315,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
-        if is_hpu:
-            import os
-            self.config_hidden_layers = int(
-                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
@@ -346,13 +341,12 @@ def forward(
         if is_hpu:
             import habana_frameworks.torch as htorch
             htorch.core.mark_step()
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
                                             kv_caches[i - self.start_layer],
                                             attn_metadata, residual)
-            if is_hpu and i % self.config_hidden_layers == 0:
-                htorch.core.mark_step()
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 92424cbf55538..b130fd9550ed4 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -277,19 +277,18 @@ def gather_list(input, indices, v):
 def flatten(in_list):
     return list(itertools.chain(*in_list))
 
-
-def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"):
-    if module.__class__.__name__.endswith(suffix):
-
-        def forward_hook(module, args, output):
-            htorch.core.mark_step()
-            return output
-
-        module.register_forward_hook(forward_hook)
+def modify_decoder_layer(module: torch.nn.Module, n=1, counter=[0], suffix="DecoderLayer"):
+    def forward_hook(module, args, output):
+        htorch.core.mark_step()
+        return output
 
     for child_name, child_module in module.named_children():
-        modify_decoder_layer(child_module)
-
+        if child_module.__class__.__name__.endswith(suffix):
+            counter[0] += 1
+            if counter[0] % n == 0:
+                child_module.register_forward_hook(forward_hook)
+            else:
+                modify_decoder_layer(child_module, n, counter)
 
 class HpuModelAdapter:
 
@@ -756,7 +755,9 @@ def load_model(self) -> None:
         elif not is_fake_hpu():
             self.model = self.model.to("hpu")
             htcore.mark_step()
-            modify_decoder_layer(self.model)
+            decoder_hidden_layers = int(
+                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
+            modify_decoder_layer(self.model, decoder_hidden_layers)
             torch.hpu.synchronize()
 
         with HabanaMemoryProfiler() as m_wrap:

From 2acfdaf17d8b1118379b136f4773c18312803fab Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Wed, 20 Nov 2024 01:16:43 +0000
Subject: [PATCH 2/6] Ruff style fix

---
 vllm/worker/hpu_model_runner.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index b130fd9550ed4..a3e1d42768285 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -277,18 +277,27 @@ def gather_list(input, indices, v):
 def flatten(in_list):
     return list(itertools.chain(*in_list))
 
-def modify_decoder_layer(module: torch.nn.Module, n=1, counter=[0], suffix="DecoderLayer"):
suffix="DecoderLayer"): + +def modify_decoder_layer(module: torch.nn.Module, + n=1, + counter=None, + suffix="DecoderLayer"): + def forward_hook(module, args, output): htorch.core.mark_step() return output + if counter is None: + counter = [0] + for child_name, child_module in module.named_children(): if child_module.__class__.__name__.endswith(suffix): counter[0] += 1 if counter[0] % n == 0: child_module.register_forward_hook(forward_hook) else: - modify_decoder_layer(child_module, n, counter) + modify_decoder_layer(child_module, n, counter) + class HpuModelAdapter: From 9252e6067b12c90f5d338a96a29b68a7339c38d9 Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Tue, 19 Nov 2024 22:29:45 -0800 Subject: [PATCH 3/6] Updated qwen and GPT-bigcode mark_step with forward-hook --- vllm/model_executor/models/gpt_bigcode.py | 6 ------ vllm/model_executor/models/qwen2.py | 3 --- vllm/worker/hpu_model_runner.py | 14 +++++++++----- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 535f74a349297..1d43f0d9e3c03 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -220,10 +220,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory(["hidden_states"], config.n_embd)) - if is_hpu: - import os - self.config_hidden_layers = int( - os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -252,8 +248,6 @@ def forward( hidden_states = layer(hidden_states, kv_caches[i - self.start_layer], attn_metadata) - if is_hpu and i % self.config_hidden_layers == 0: - htorch.core.mark_step() if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e894c6e506aff..6a18e5fea53f2 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -328,9 +328,6 @@ def forward( attn_metadata, residual, ) - if current_platform.is_hpu(): - htorch.core.mark_step() - if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index a3e1d42768285..d50a1f55de02f 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -279,9 +279,9 @@ def flatten(in_list): def modify_decoder_layer(module: torch.nn.Module, + suffix="DecoderLayer", n=1, - counter=None, - suffix="DecoderLayer"): + counter=None): def forward_hook(module, args, output): htorch.core.mark_step() @@ -296,7 +296,7 @@ def forward_hook(module, args, output): if counter[0] % n == 0: child_module.register_forward_hook(forward_hook) else: - modify_decoder_layer(child_module, n, counter) + modify_decoder_layer(child_module, suffix, n, counter) class HpuModelAdapter: @@ -764,9 +764,13 @@ def load_model(self) -> None: elif not is_fake_hpu(): self.model = self.model.to("hpu") htcore.mark_step() - decoder_hidden_layers = int( + + hidden_layer_markstep = int( os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) - modify_decoder_layer(self.model, decoder_hidden_layers) + hideen_layer_suffix = os.getenv('VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX', + 'DecodeLayer') + modify_decoder_layer(self.model, hideen_layer_suffix, + 
+                                 hidden_layer_markstep)
             torch.hpu.synchronize()
 
         with HabanaMemoryProfiler() as m_wrap:

From 60e7f49784a9d73b08116c7ab6ab65226a9c14e0 Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Wed, 20 Nov 2024 09:53:30 -0800
Subject: [PATCH 4/6] Typo error fix and Update based on the comment

---
 vllm/worker/hpu_model_runner.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index d50a1f55de02f..1de6b7f7b7a16 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -765,12 +765,12 @@ def load_model(self) -> None:
             self.model = self.model.to("hpu")
             htcore.mark_step()
 
-            hidden_layer_markstep = int(
+            hidden_layer_markstep_interval = int(
                 os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
             hideen_layer_suffix = os.getenv('VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX',
-                                            'DecodeLayer')
+                                            'DecoderLayer')
             modify_decoder_layer(self.model, hideen_layer_suffix,
-                                 hidden_layer_markstep)
+                                 hidden_layer_markstep_interval)
             torch.hpu.synchronize()
 
         with HabanaMemoryProfiler() as m_wrap:

From 0569543f1f18b09fb8ca8d1c03bc919fe16c425e Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Thu, 21 Nov 2024 16:25:18 -0800
Subject: [PATCH 5/6] Add table mapping for hidden layer suffix

---
 vllm/worker/hpu_model_runner.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 1de6b7f7b7a16..da5b6df7b5932 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -278,6 +278,18 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))
 
 
+def get_decoder_layer_suffix(model_type):
+    # This sets the suffix for the hidden layer name which is controlled by VLLM_CONFIG_HIDDEN_LAYERS.
+    # The default suffix is "DecoderLayer," which is applicable for most language models such as LLaMA, Qwen, and BART.
+    # If the model's decoder layer name differs from the default, it will need to be specified here.
+    # For example, for the GPT-BigCode model, this value should be set to "BigCodeBlock".
+    decoder_layer_table = {
+        "gpt_bigcode": "BigCodeBlock",
+    }
+
+    return decoder_layer_table.get(model_type, "DecoderLayer")
+
+
 def modify_decoder_layer(module: torch.nn.Module,
                          suffix="DecoderLayer",
                          n=1,
@@ -767,10 +779,10 @@ def load_model(self) -> None:
 
             hidden_layer_markstep_interval = int(
                 os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
-            hideen_layer_suffix = os.getenv('VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX',
-                                            'DecoderLayer')
-            modify_decoder_layer(self.model, hideen_layer_suffix,
-                                 hidden_layer_markstep_interval)
+            modify_decoder_layer(
+                self.model,
+                get_decoder_layer_suffix(self.model.config.model_type),
+                hidden_layer_markstep_interval)
             torch.hpu.synchronize()
 
         with HabanaMemoryProfiler() as m_wrap:

From 4f56b204c779fd4562df34e28b7cb1af4d4f394f Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Thu, 21 Nov 2024 16:55:15 -0800
Subject: [PATCH 6/6] Fix ruff error

---
 vllm/worker/hpu_model_runner.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index da5b6df7b5932..84529846ae44f 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -279,10 +279,11 @@ def flatten(in_list):
 
 
 def get_decoder_layer_suffix(model_type):
-    # This sets the suffix for the hidden layer name which is controlled by VLLM_CONFIG_HIDDEN_LAYERS.
-    # The default suffix is "DecoderLayer," which is applicable for most language models such as LLaMA, Qwen, and BART.
-    # If the model's decoder layer name differs from the default, it will need to be specified here.
-    # For example, for the GPT-BigCode model, this value should be set to "BigCodeBlock".
+    # This sets the suffix for the hidden layer name, which is controlled by
+    # VLLM_CONFIG_HIDDEN_LAYERS. The default suffix is "DecoderLayer," which is
+    # applicable for most language models such as LLaMA, Qwen, and BART. If the
+    # model's decoder layer name differs from the default, it will need to
+    # be specified here.
     decoder_layer_table = {
         "gpt_bigcode": "BigCodeBlock",
     }
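
For reference, a minimal, self-contained sketch of the mechanism the series
converges on: a forward hook that calls mark_step() after every n-th module
whose class name ends with the configured suffix. modify_decoder_layer below
matches the final form in hpu_model_runner.py; the toy model and the no-op
mark_step() stub (standing in for htorch.core.mark_step so this runs on a
non-HPU host) are assumptions for illustration only. In the real runner, n
comes from VLLM_CONFIG_HIDDEN_LAYERS (default '1') and the suffix from
get_decoder_layer_suffix(model_type).

    import torch

    def mark_step():
        # Stand-in for htorch.core.mark_step(); prints so hook firing is visible.
        print("mark_step")

    def modify_decoder_layer(module: torch.nn.Module,
                             suffix="DecoderLayer",
                             n=1,
                             counter=None):

        def forward_hook(module, args, output):
            mark_step()
            return output

        if counter is None:
            counter = [0]

        for child_name, child_module in module.named_children():
            if child_module.__class__.__name__.endswith(suffix):
                counter[0] += 1
                if counter[0] % n == 0:
                    child_module.register_forward_hook(forward_hook)
            else:
                modify_decoder_layer(child_module, suffix, n, counter)

    class ToyDecoderLayer(torch.nn.Module):
        def forward(self, x):
            return x

    class ToyModel(torch.nn.Module):
        def __init__(self, num_layers=4):
            super().__init__()
            self.layers = torch.nn.ModuleList(
                ToyDecoderLayer() for _ in range(num_layers))

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    model = ToyModel()
    modify_decoder_layer(model, suffix="DecoderLayer", n=2)
    model(torch.zeros(1))  # hooks on the 2nd and 4th layers: prints "mark_step" twice

Note that the traversal only recurses into children whose class name does not
match the suffix, so with n=1 every matching layer gets a hook, and with n=2
every second one does, counted in named_children() order across the model.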