
[HPU] Add mark_step configurable for the decoder layer. #525

Merged: 6 commits into habana_main, Nov 26, 2024

Conversation

@jiminha commented Nov 20, 2024

We are seeing a 10% performance regression in llama-based models due to vllm-project#10239. The mark_step() function needs to be configured differently for each model to achieve the best performance: for some models, calling mark_step() at every decoder step is optimal, while for others it is better to call it every n-th step. This PR adds a counter so the hook is registered only on every n-th decoder layer, configurable with VLLM_CONFIG_HIDDEN_LAYERS.
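
For illustration, here is a minimal sketch of the mechanism described above, assuming PyTorch forward hooks and the htcore.mark_step() API; the name modify_decoder_layer matches the diff quoted later in this thread, but the signature, defaults, and suffix matching shown here are assumptions, not the exact PR code:

    # Sketch only: register a mark_step() forward hook on every
    # `interval`-th module whose class name ends with `suffix`.
    import habana_frameworks.torch.core as htcore
    import torch.nn as nn

    def modify_decoder_layer(model: nn.Module,
                             suffix: str = "DecoderLayer",
                             interval: int = 1) -> None:
        def forward_hook(module, args, output):
            # Breaking the graph here bounds graph size on HPU.
            htcore.mark_step()
            return output

        count = 0
        for module in model.modules():
            if module.__class__.__name__.endswith(suffix):
                if count % interval == 0:
                    module.register_forward_hook(forward_hook)
                count += 1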

@jiminha requested a review from libinta on November 20, 2024 00:55
@zhouyuan commented:

CC @jikunshang

for i in range(self.start_layer, self.end_layer):
    layer = self.layers[i]
    hidden_states, residual = layer(positions, hidden_states,
                                    kv_caches[i - self.start_layer],
                                    attn_metadata, residual)
    if is_hpu and i % self.config_hidden_layers == 0:
        htcore.mark_step()  # assumed body; the quoted diff is truncated here

A reviewer commented:
I noticed qwen.py (and some other model files) also added mark_step previously; we can remove those now.

@jiminha (Author) replied:

OK, I will take a look. It seems bigcode also uses this configuration parameter, but it is not a DecoderLayer that needs a mark_step there; it is something else (88 x GPTBigCodeBlock), so we will also need a different suffix configuration for different models.

Will check further and update.

@jiminha (Author) replied:

@jikunshang Please check the new batch. For bigcode, I need to run it with VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX="BigCodeBlock". By the way, which model were the original code changes for?
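
A hedged example of how this suffix configuration might be set for the bigcode case (the interval value here is illustrative, not taken from the PR):

    # Hypothetical settings for bigcode; the interval "2" is illustrative only.
    import os

    os.environ["VLLM_CONFIG_HIDDEN_LAYERS"] = "2"
    os.environ["VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX"] = "BigCodeBlock"
    # The 88 GPTBigCodeBlock modules match because
    # "GPTBigCodeBlock".endswith("BigCodeBlock") is True.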

vllm/worker/hpu_model_runner.py
@@ -756,7 +764,13 @@ def load_model(self) -> None:
    elif not is_fake_hpu():
        self.model = self.model.to("hpu")
    htcore.mark_step()
    modify_decoder_layer(self.model)

hidden_layer_markstep = int(

A reviewer commented:

Maybe "hidden_layer_markstep_interval" would be a better name?

@jiminha (Author) replied:

Updated!


hidden_layer_markstep_interval = int(
    os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
hidden_layer_suffix = os.getenv('VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX',

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a new env variable; please add a description in README_GAUDI.md and gaudi-installation.rst.
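
For reference, a sketch of how the two variables combine with the hook helper sketched earlier; read_markstep_config is a hypothetical helper, and the quoted diff is truncated, so the default suffix shown here is an assumption:

    # Hedged sketch; the 'DecoderLayer' default is an assumption, since
    # the quoted diff cuts off before the default argument.
    import os

    def read_markstep_config() -> tuple[int, str]:
        interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
        suffix = os.getenv('VLLM_CONFIG_HIDDEN_LAYERS_SUFFIX',
                           'DecoderLayer')  # assumed default
        return interval, suffix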

@michalkuligowski merged commit b62f1b2 into habana_main on Nov 26, 2024
9 checks passed
@michalkuligowski deleted the jha/markstep_config branch on November 26, 2024 10:12