[HPU] Add mark_step configurable for the decoder layer. #525

Merged
merged 6 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
vllm/model_executor/models/gpt_bigcode.py (0 additions, 6 deletions)

```diff
@@ -220,10 +220,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.n_embd))
-        if is_hpu:
-            import os
-            self.config_hidden_layers = int(
-                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.wte(input_ids)
@@ -252,8 +248,6 @@ def forward(
             hidden_states = layer(hidden_states,
                                   kv_caches[i - self.start_layer],
                                   attn_metadata)
-            if is_hpu and i % self.config_hidden_layers == 0:
-                htorch.core.mark_step()
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
         hidden_states = self.ln_f(hidden_states)
```
vllm/model_executor/models/llama.py (1 addition, 7 deletions)

```diff
@@ -315,11 +315,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))

-        if is_hpu:
-            import os
-            self.config_hidden_layers = int(
-                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
-
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

@@ -346,13 +341,12 @@ def forward(
         if is_hpu:
             import habana_frameworks.torch as htorch
             htorch.core.mark_step()

         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
                                             kv_caches[i - self.start_layer],
                                             attn_metadata, residual)
-            if is_hpu and i % self.config_hidden_layers == 0:
-                htorch.core.mark_step()
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
```

Review comment: I noticed qwen.py (and some other model files) also added a mark_step previously; we can remove those as well.

Author: OK, I will take a look. It seems bigcode also uses this configuration parameter, but there it is not a DecoderLayer that needs the mark_step; it is something else (88 x GPTBigCodeBlock), so we will also need a configurable suffix for different models. Will check further and update.

Author: @jikunshang Please check the new batch. For bigcode, I need to run it with VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX="BigCodeBlock". By the way, which model were the original code changes for?
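The suffix handling the author describes is a plain endswith check on the module class name, which is why a single hard-coded "DecoderLayer" suffix does not cover gpt_bigcode. A small illustration (class names taken from the thread and from vLLM's model code):

```python
# Why a configurable suffix is needed: the hook matcher keys on the class-name
# suffix, and gpt_bigcode's transformer blocks do not end with "DecoderLayer".
print("LlamaDecoderLayer".endswith("DecoderLayer"))   # True
print("GPTBigCodeBlock".endswith("DecoderLayer"))     # False
print("GPTBigCodeBlock".endswith("BigCodeBlock"))     # True
```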
vllm/model_executor/models/qwen2.py (0 additions, 3 deletions)

```diff
@@ -328,9 +328,6 @@ def forward(
                 attn_metadata,
                 residual,
             )
-            if current_platform.is_hpu():
-                htorch.core.mark_step()
-
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
```
vllm/worker/hpu_model_runner.py (22 additions, 8 deletions)

```diff
@@ -278,17 +278,25 @@ def flatten(in_list):
     return list(itertools.chain(*in_list))


-def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"):
-    if module.__class__.__name__.endswith(suffix):
+def modify_decoder_layer(module: torch.nn.Module,
+                         suffix="DecoderLayer",
+                         n=1,
+                         counter=None):

-        def forward_hook(module, args, output):
-            htorch.core.mark_step()
-            return output
+    def forward_hook(module, args, output):
+        htorch.core.mark_step()
+        return output

-        module.register_forward_hook(forward_hook)
+    if counter is None:
+        counter = [0]

     for child_name, child_module in module.named_children():
-        modify_decoder_layer(child_module)
+        if child_module.__class__.__name__.endswith(suffix):
+            counter[0] += 1
+            if counter[0] % n == 0:
+                child_module.register_forward_hook(forward_hook)
+        else:
+            modify_decoder_layer(child_module, suffix, n, counter)


 class HpuModelAdapter:
```
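To make the new traversal concrete, here is a minimal, self-contained sketch of the same pattern on a toy model. This is plain PyTorch with a print standing in for htorch.core.mark_step(), and ToyDecoderLayer/ToyModel are made-up names for illustration only:

```python
import torch
import torch.nn as nn


def fake_mark_step():
    # Stand-in for htorch.core.mark_step(): just records that a graph break
    # would be requested at this point.
    print("mark_step")


def modify_decoder_layer(module, suffix="DecoderLayer", n=1, counter=None):
    # Same traversal as in the PR: walk the module tree, count every child
    # whose class name ends with `suffix`, and attach a forward hook to every
    # n-th matching child; recurse only into non-matching children.
    def forward_hook(module, args, output):
        fake_mark_step()
        return output

    if counter is None:
        counter = [0]
    for _, child in module.named_children():
        if child.__class__.__name__.endswith(suffix):
            counter[0] += 1
            if counter[0] % n == 0:
                child.register_forward_hook(forward_hook)
        else:
            modify_decoder_layer(child, suffix, n, counter)


class ToyDecoderLayer(nn.Module):
    def forward(self, x):
        return x + 1


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyDecoderLayer() for _ in range(6)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = ToyModel()
modify_decoder_layer(model, suffix="DecoderLayer", n=2)
model(torch.zeros(1))  # prints "mark_step" after the 2nd, 4th and 6th layer
```

With n=2 only every second matching block gets a hook, so the toy forward pass triggers three mark_step calls instead of six; n=1 reproduces the previous per-layer behaviour.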
```diff
@@ -756,7 +764,13 @@ def load_model(self) -> None:
         elif not is_fake_hpu():
             self.model = self.model.to("hpu")
             htcore.mark_step()
-            modify_decoder_layer(self.model)
+
+            hidden_layer_markstep = int(
+                os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
+            hidden_layer_suffix = os.getenv('VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX',
+                                            'DecodeLayer')
+            modify_decoder_layer(self.model, hidden_layer_suffix,
+                                 hidden_layer_markstep)
             torch.hpu.synchronize()

         with HabanaMemoryProfiler() as m_wrap:
```

Review comment (on hidden_layer_markstep): Maybe "hidden_layer_markstep_interval" would be a better name?

Author: Updated!

Review comment (on VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX): This is a new environment variable; please add a description to README_GAUDI.md and gaudi-installation.rst.
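With this change the behaviour is selected entirely through the two environment variables read in load_model(). A hedged usage sketch (the values are illustrative; "BigCodeBlock" is the suffix the author reports using for gpt_bigcode):

```python
import os

# Illustrative settings only: request a mark_step after every 4th module whose
# class name ends with "BigCodeBlock" (gpt_bigcode), instead of the default of
# a mark_step after every module whose name ends with "DecodeLayer".
os.environ["VLLM_CONFIG_HIDDEN_LAYERS"] = "4"
os.environ["VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX"] = "BigCodeBlock"

# ...then construct the vLLM engine as usual; load_model() reads these values
# when it calls modify_decoder_layer(self.model, ...).
```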