From 49a11e29734707e005fac6a468cd9d51fa58fa50 Mon Sep 17 00:00:00 2001
From: Yan Ma
Date: Wed, 8 Jan 2025 23:58:24 +0800
Subject: [PATCH] Add mark_step for encoder layers (#669)

This is an updated version of https://github.com/HabanaAI/vllm-fork/pull/650.
Coupled with [Use FusedSDPA for MllamaVisionSdpaAttention](https://github.com/HabanaAI/vllm-fork/pull/620),
it resolves two issues seen when running the Llama 3.2 vision model:

- GC failure when batch size > 1 on Gaudi3.
- Increased device memory consumption with Torch 2.5 compared to Torch 2.4.

---------

Signed-off-by: yan ma
Co-authored-by: yisonzhu
---
 vllm/worker/hpu_model_runner.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 96f14d5faffb0..b5cfcf23d6e83 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -135,7 +135,7 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))
 
 
-def get_decoder_layer_suffix(model_type):
+def get_target_layer_suffix_list(model_type) -> list[str]:
     # This sets the suffix for the hidden layer name, which is controlled by
     # VLLM_CONFIG_HIDDEN_LAYERS. The default suffix is "DecoderLayer," which is
     # applicable for most language models such as LLaMA, Qwen, and BART. If the
@@ -145,13 +145,17 @@ def get_decoder_layer_suffix(model_type):
         "gpt_bigcode": "BigCodeBlock",
     }
 
-    return decoder_layer_table.get(model_type, "DecoderLayer")
+    return [
+        decoder_layer_table.get(model_type, "DecoderLayer"), "EncoderLayer"
+    ]
 
 
-def modify_decoder_layer(module: torch.nn.Module,
-                         suffix="DecoderLayer",
-                         n=1,
-                         counter=None):
+def modify_model_layers(module: torch.nn.Module,
+                        suffix_list: list[str],
+                        n=1,
+                        counter=None):
+    """Currently add mark_step at the end of specified layers.
+    """
 
     def forward_hook(module, args, output):
         htorch.core.mark_step()
@@ -161,12 +165,14 @@ def forward_hook(module, args, output):
         counter = [0]
 
     for child_name, child_module in module.named_children():
-        if child_module.__class__.__name__.endswith(suffix):
+        if any(
+                child_module.__class__.__name__.endswith(layer)
+                for layer in suffix_list):
             counter[0] += 1
             if counter[0] % n == 0:
                 child_module.register_forward_hook(forward_hook)
         else:
-            modify_decoder_layer(child_module, suffix, n, counter)
+            modify_model_layers(child_module, suffix_list, n, counter)
 
 
 def get_path_to_rope(model: torch.nn.Module):
@@ -753,10 +759,11 @@ def load_model(self) -> None:
             hidden_layer_markstep_interval = int(
                 os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
             model_config = getattr(self.model, "config", None)
-            modify_decoder_layer(
+            modify_model_layers(
                 self.model,
-                get_decoder_layer_suffix(model_config.model_type if
-                                         model_config is not None else None),
+                get_target_layer_suffix_list(
+                    model_config.
+                    model_type if model_config is not None else None),
                 hidden_layer_markstep_interval)
             path_to_rope = get_path_to_rope(self.model)
             torch.hpu.synchronize()
@@ -1969,7 +1976,7 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
         This is a helper function to create the mask for lora computations.
         Lora Mask is needed to ensure we match the correct lora weights
         for the for the request.
-        For Prompt phase we have
+        For Prompt phase we have
         lora_mask with shape (batch_size * seq_len, max_loras * max_rank)
         lora_logits_mask with shape (batch_size, max_loras * max_rank)
         For Decode phase we have both
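
For reference, the hook mechanism this patch generalizes to encoder layers can be exercised outside vLLM. The sketch below is illustrative only and is not part of the patch: it mirrors the `modify_model_layers` / `get_target_layer_suffix_list` logic from the diff, but substitutes a print for `htorch.core.mark_step()` so it runs without Gaudi hardware, and the `ToyEncoderLayer` / `ToyModel` classes are made up for the example.

```python
# Illustrative sketch only -- not part of the patch. It mirrors the hook
# registration added in this commit, but replaces htorch.core.mark_step()
# with a print so it runs on CPU. ToyEncoderLayer/ToyModel are hypothetical.
import torch


class ToyEncoderLayer(torch.nn.Module):
    def forward(self, x):
        return x + 1


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([ToyEncoderLayer() for _ in range(4)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def modify_model_layers(module, suffix_list, n=1, counter=None):
    """Register a hook at the end of every n-th layer whose class name
    ends with one of the suffixes in suffix_list."""

    def forward_hook(module, args, output):
        # Stand-in for htorch.core.mark_step() on Gaudi.
        print(f"mark_step after {module.__class__.__name__}")

    if counter is None:
        counter = [0]

    for _, child in module.named_children():
        if any(child.__class__.__name__.endswith(s) for s in suffix_list):
            counter[0] += 1
            if counter[0] % n == 0:
                child.register_forward_hook(forward_hook)
        else:
            modify_model_layers(child, suffix_list, n, counter)


# Hook every 2nd matching layer, i.e. the VLLM_CONFIG_HIDDEN_LAYERS=2 case.
model = ToyModel()
modify_model_layers(model, ["DecoderLayer", "EncoderLayer"], n=2)
model(torch.zeros(1))  # prints a mark_step message after layers 2 and 4
```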