diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 535f74a349297..1d43f0d9e3c03 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -220,10 +220,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.n_embd))
-        if is_hpu:
-            import os
-            self.config_hidden_layers = int(
-                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.wte(input_ids)
@@ -252,8 +248,6 @@ def forward(
             hidden_states = layer(hidden_states,
                                   kv_caches[i - self.start_layer],
                                   attn_metadata)
-            if is_hpu and i % self.config_hidden_layers == 0:
-                htorch.core.mark_step()
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.ln_f(hidden_states)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 1603bc18f132f..11dbbde24d67e 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -315,11 +315,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
-        if is_hpu:
-            import os
-            self.config_hidden_layers = int(
-                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
@@ -346,13 +341,12 @@ def forward(
         if is_hpu:
             import habana_frameworks.torch as htorch
             htorch.core.mark_step()
+
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
                                             kv_caches[i - self.start_layer],
                                             attn_metadata, residual)
-            if is_hpu and i % self.config_hidden_layers == 0:
-                htorch.core.mark_step()
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index e894c6e506aff..6a18e5fea53f2 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -328,9 +328,6 @@ def forward(
                 attn_metadata,
                 residual,
             )
-            if current_platform.is_hpu():
-                htorch.core.mark_step()
-
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 92424cbf55538..84529846ae44f 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -278,17 +278,38 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))
 
 
-def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"):
-    if module.__class__.__name__.endswith(suffix):
+def get_decoder_layer_suffix(model_type):
+    # This sets the suffix for the hidden layer name, which is controlled by
+    # VLLM_CONFIG_HIDDEN_LAYERS. The default suffix is "DecoderLayer," which is
+    # applicable for most language models such as LLaMA, Qwen, and BART. If the
+    # model's decoder layer name differs from the default, it will need to
+    # be specified here.
+    decoder_layer_table = {
+        "gpt_bigcode": "BigCodeBlock",
+    }
 
-        def forward_hook(module, args, output):
-            htorch.core.mark_step()
-            return output
+    return decoder_layer_table.get(model_type, "DecoderLayer")
+
+
+def modify_decoder_layer(module: torch.nn.Module,
+                         suffix="DecoderLayer",
+                         n=1,
+                         counter=None):
 
-        module.register_forward_hook(forward_hook)
+    def forward_hook(module, args, output):
+        htorch.core.mark_step()
+        return output
+
+    if counter is None:
+        counter = [0]
 
     for child_name, child_module in module.named_children():
-        modify_decoder_layer(child_module)
+        if child_module.__class__.__name__.endswith(suffix):
+            counter[0] += 1
+            if counter[0] % n == 0:
+                child_module.register_forward_hook(forward_hook)
+        else:
+            modify_decoder_layer(child_module, suffix, n, counter)
 
 
 class HpuModelAdapter:
@@ -756,7 +777,13 @@ def load_model(self) -> None:
         elif not is_fake_hpu():
             self.model = self.model.to("hpu")
             htcore.mark_step()
-        modify_decoder_layer(self.model)
+
+        hidden_layer_markstep_interval = int(
+            os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
+        modify_decoder_layer(
+            self.model,
+            get_decoder_layer_suffix(self.model.config.model_type),
+            hidden_layer_markstep_interval)
         torch.hpu.synchronize()
 
         with HabanaMemoryProfiler() as m_wrap:
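
Note (not part of the patch): the sketch below mimics how the counter-based traversal in modify_decoder_layer registers the mark_step hook on every n-th decoder layer, which is what setting VLLM_CONFIG_HIDDEN_LAYERS=2 would give on HPU. The MiniDecoderLayer/MiniModel classes are made up for illustration, and htorch.core.mark_step() is replaced by a print so the example runs on any host with plain PyTorch.

    import torch


    class MiniDecoderLayer(torch.nn.Module):  # illustrative stand-in layer
        def forward(self, x):
            return x


    class MiniModel(torch.nn.Module):  # illustrative stand-in model
        def __init__(self, num_layers):
            super().__init__()
            self.layers = torch.nn.ModuleList(
                MiniDecoderLayer() for _ in range(num_layers))


    def modify_decoder_layer(module, suffix="DecoderLayer", n=1, counter=None):
        # Same traversal as the patch; htorch.core.mark_step() is swapped for
        # a print so the sketch runs without habana_frameworks installed.
        def forward_hook(mod, args, output):
            print(f"mark_step after {mod.__class__.__name__}")
            return output

        if counter is None:
            counter = [0]
        for child_name, child_module in module.named_children():
            if child_module.__class__.__name__.endswith(suffix):
                counter[0] += 1
                if counter[0] % n == 0:
                    child_module.register_forward_hook(forward_hook)
            else:
                modify_decoder_layer(child_module, suffix, n, counter)


    model = MiniModel(4)
    modify_decoder_layer(model, suffix="DecoderLayer", n=2)
    x = torch.zeros(1)
    for layer in model.layers:
        x = layer(x)  # the hook fires after the 2nd and 4th layers only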