-
Notifications
You must be signed in to change notification settings - Fork 64
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[HPU] Add mark_step configurable for the decoder layer. #525
Changes from 3 commits
aeac04a
2acfdaf
9252e60
60e7f49
0569543
4f56b20
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -278,17 +278,25 @@ def flatten(in_list): | |
return list(itertools.chain(*in_list)) | ||
|
||
|
||
def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"): | ||
if module.__class__.__name__.endswith(suffix): | ||
def modify_decoder_layer(module: torch.nn.Module, | ||
suffix="DecoderLayer", | ||
n=1, | ||
counter=None): | ||
|
||
def forward_hook(module, args, output): | ||
htorch.core.mark_step() | ||
return output | ||
def forward_hook(module, args, output): | ||
htorch.core.mark_step() | ||
return output | ||
|
||
module.register_forward_hook(forward_hook) | ||
if counter is None: | ||
counter = [0] | ||
|
||
for child_name, child_module in module.named_children(): | ||
michalkuligowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
modify_decoder_layer(child_module) | ||
if child_module.__class__.__name__.endswith(suffix): | ||
counter[0] += 1 | ||
if counter[0] % n == 0: | ||
child_module.register_forward_hook(forward_hook) | ||
else: | ||
modify_decoder_layer(child_module, suffix, n, counter) | ||
|
||
|
||
class HpuModelAdapter: | ||
|
@@ -756,7 +764,13 @@ def load_model(self) -> None: | |
elif not is_fake_hpu(): | ||
self.model = self.model.to("hpu") | ||
htcore.mark_step() | ||
modify_decoder_layer(self.model) | ||
|
||
hidden_layer_markstep = int( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe "hidden_layer_markstep_interval" would be a better name? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated! |
||
os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) | ||
hideen_layer_suffix = os.getenv('VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a new env variable please add description in README_GAUDI.md and gaudi-installation.rst |
||
'DecodeLayer') | ||
modify_decoder_layer(self.model, hideen_layer_suffix, | ||
hidden_layer_markstep) | ||
torch.hpu.synchronize() | ||
|
||
with HabanaMemoryProfiler() as m_wrap: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed qwen.py(and some other model files) also add
mark_step
previously, we can remove it.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, I will take a look. It seems bigcode also use this configuration parameter, but it's not DecodeLayer they need a markstep, it's something else(88 x GPTBigCodeBlock), so we will need different suffix configuration as well for different model.
Will check out further and update.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jikunshang Please check the new batch. For the bigcode, I need to run it with VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX="BigCodeBlock". By the way, which model was the original code changes for?