Add mark_step for encoder layers (#669)
This is an updated version of
#650.


Coupled with [Use FusedSDPA for MllamaVisionSdpaAttention
#620], it resolves two issues that arise
when running the Llama 3.2 vision model:

- GC failure when batch size > 1 on Gaudi3.
- Increased device memory consumption with Torch 2.5 compared to Torch 2.4.
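
In essence, the change registers a forward hook that calls
htorch.core.mark_step() on every n-th layer whose class name ends with a
target suffix, so that in lazy mode the accumulated graph is broken at
layer boundaries instead of spanning the whole model. A condensed sketch
of that mechanism (illustrative only: it assumes a Gaudi host with
habana_frameworks installed, and add_mark_step_hooks is a hypothetical
name, not part of the patch):

    import torch
    import habana_frameworks.torch as htorch

    def add_mark_step_hooks(module: torch.nn.Module,
                            suffixes: list[str],
                            n: int = 1) -> None:
        count = 0

        def forward_hook(mod, args, output):
            # Flush the ops accumulated so far to the graph compiler,
            # splitting one large lazy-mode graph into per-layer pieces.
            htorch.core.mark_step()

        for child in module.modules():
            if any(type(child).__name__.endswith(s) for s in suffixes):
                count += 1
                if count % n == 0:
                    child.register_forward_hook(forward_hook)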

---------

Signed-off-by: yan ma <[email protected]>
Co-authored-by: yisonzhu <[email protected]>
yma11 and yisonzhu authored Jan 8, 2025
1 parent 8f53dee commit 49a11e2
Showing 1 changed file with 19 additions and 12 deletions.
vllm/worker/hpu_model_runner.py (19 additions, 12 deletions):

@@ -135,7 +135,7 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))
 
 
-def get_decoder_layer_suffix(model_type):
+def get_target_layer_suffix_list(model_type) -> list[str]:
     # This sets the suffix for the hidden layer name, which is controlled by
     # VLLM_CONFIG_HIDDEN_LAYERS. The default suffix is "DecoderLayer," which is
     # applicable for most language models such as LLaMA, Qwen, and BART. If the
@@ -145,13 +145,17 @@ def get_decoder_layer_suffix(model_type):
         "gpt_bigcode": "BigCodeBlock",
     }
 
-    return decoder_layer_table.get(model_type, "DecoderLayer")
+    return [
+        decoder_layer_table.get(model_type, "DecoderLayer"), "EncoderLayer"
+    ]
 
 
-def modify_decoder_layer(module: torch.nn.Module,
-                         suffix="DecoderLayer",
-                         n=1,
-                         counter=None):
+def modify_model_layers(module: torch.nn.Module,
+                        suffix_list: list[str],
+                        n=1,
+                        counter=None):
+    """Currently add mark_step at the end of specified layers.
+    """
 
     def forward_hook(module, args, output):
         htorch.core.mark_step()
@@ -161,12 +165,14 @@ def forward_hook(module, args, output):
         counter = [0]
 
     for child_name, child_module in module.named_children():
-        if child_module.__class__.__name__.endswith(suffix):
+        if any(
+                child_module.__class__.__name__.endswith(layer)
+                for layer in suffix_list):
             counter[0] += 1
             if counter[0] % n == 0:
                 child_module.register_forward_hook(forward_hook)
         else:
-            modify_decoder_layer(child_module, suffix, n, counter)
+            modify_model_layers(child_module, suffix_list, n, counter)
 
 
 def get_path_to_rope(model: torch.nn.Module):
@@ -753,10 +759,11 @@ def load_model(self) -> None:
             hidden_layer_markstep_interval = int(
                 os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
             model_config = getattr(self.model, "config", None)
-            modify_decoder_layer(
+            modify_model_layers(
                 self.model,
-                get_decoder_layer_suffix(model_config.model_type if
-                                         model_config is not None else None),
+                get_target_layer_suffix_list(
+                    model_config.
+                    model_type if model_config is not None else None),
                 hidden_layer_markstep_interval)
             path_to_rope = get_path_to_rope(self.model)
             torch.hpu.synchronize()
@@ -1969,7 +1976,7 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
         This is a helper function to create the mask for lora computations.
         Lora Mask is needed to ensure we match the correct lora weights for the
         for the request.
-        For Prompt phase we have 
+        For Prompt phase we have
         lora_mask with shape (batch_size * seq_len, max_loras * max_rank)
         lora_logits_mask with shape (batch_size, max_loras * max_rank)
         For Decode phase we have both
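
For reference, a usage sketch of the two renamed helpers, mirroring the
load_model call site above. The toy model is invented for illustration,
and since vllm.worker.hpu_model_runner pulls in Habana frameworks at
module level, this only runs on a Gaudi host:

    import os
    import torch
    from vllm.worker.hpu_model_runner import (get_target_layer_suffix_list,
                                              modify_model_layers)

    class ToyEncoderLayer(torch.nn.Module):

        def forward(self, x):
            return x

    class ToyModel(torch.nn.Module):

        def __init__(self):
            super().__init__()
            self.layers = torch.nn.ModuleList(
                ToyEncoderLayer() for _ in range(4))

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    # "llama" is absent from decoder_layer_table, so this returns
    # ["DecoderLayer", "EncoderLayer"]; ToyEncoderLayer matches the latter.
    suffixes = get_target_layer_suffix_list("llama")
    interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))

    model = ToyModel()
    modify_model_layers(model, suffixes, interval)
    # With VLLM_CONFIG_HIDDEN_LAYERS=2, a mark_step hook lands on every
    # second matching layer; the default of 1 hooks every matching layer.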

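As a quick illustration of the shapes described in the create_lora_mask
docstring touched by the last hunk (all values below are made up):

    # Made-up sizes, purely to illustrate the docstring's shape formulas.
    batch_size, seq_len, max_loras, max_rank = 2, 8, 4, 16
    lora_mask_shape = (batch_size * seq_len, max_loras * max_rank)  # (16, 64)
    lora_logits_mask_shape = (batch_size, max_loras * max_rank)     # (2, 64)
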