[BugFix][Habana_main][Multistep]Fix multistep deepcopy overhead (#452)
xuechendi authored Nov 6, 2024
1 parent 1033c3e commit c3c0e90
Showing 1 changed file with 24 additions and 3 deletions.
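The change replaces a per-step `copy.deepcopy` of `seq_group_metadata_list` with a single copy on the first step: each sequence's original `output_token_ids` length is cached once, dummy placeholder tokens are appended so the next step's decode preparation has a token slot, and `try_revert_dummy_output_tokens` slices every sequence back to its cached length on the early-exit paths. Below is a minimal sketch of this cache-and-revert pattern; `SeqData` and `run_multistep` are simplified stand-ins for illustration, not vLLM's real `SequenceData` API.

```python
from typing import Dict


class SeqData:
    """Illustrative stand-in for vLLM's SequenceData."""

    def __init__(self, output_token_ids: tuple):
        self.output_token_ids = output_token_ids


def run_multistep(seq_data: Dict[int, SeqData], num_steps: int) -> None:
    # Cache each sequence's original output length once (on step 0)
    # instead of deep-copying all sequence-group metadata every step.
    cache_orig_output_tokens_len: Dict[int, int] = {}

    def try_revert_dummy_output_tokens() -> None:
        # Slice each sequence back to its cached length, dropping any
        # dummy placeholder tokens appended below. Guarded, as in the
        # commit, so it is a no-op when nothing was cached.
        if not cache_orig_output_tokens_len:
            return
        for seq_id, data in seq_data.items():
            orig_len = cache_orig_output_tokens_len[seq_id]
            data.output_token_ids = data.output_token_ids[:orig_len]

    for step in range(num_steps):
        if step < num_steps - 1:
            if step == 0:
                for seq_id, data in seq_data.items():
                    cache_orig_output_tokens_len[seq_id] = len(
                        data.output_token_ids)
            # Placeholder slot for the next step's decode preparation;
            # the value is arbitrary (the commit uses 540).
            for data in seq_data.values():
                data.output_token_ids += (540, )

    try_revert_dummy_output_tokens()


# After the loop the sequences hold exactly their original tokens again.
seqs = {0: SeqData((1, 2, 3))}
run_multistep(seqs, num_steps=4)
assert seqs[0].output_token_ids == (1, 2, 3)
```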
27 changes: 24 additions & 3 deletions vllm/worker/hpu_model_runner.py
@@ -2115,6 +2115,19 @@ def execute_model(
            # we only want to pythonize in the last step
            sampling_metadata.skip_sampler_cpu_output = True
            self.model.model.sampler.include_gpu_probs_tensor = True
        cache_orig_output_tokens_len: List[Dict] = []

        def try_revert_dummy_output_tokens():
            if len(cache_orig_output_tokens_len) > 0:
                # Truncate output_token_ids back to its original length
                for i, seq_group_metadata in enumerate(
                        seq_group_metadata_list):
                    for j, data in seq_group_metadata.seq_data.items():
                        orig_output_tokens_len = \
                            cache_orig_output_tokens_len[i][j]
                        data.output_token_ids = \
                            data.output_token_ids[:orig_output_tokens_len]

        for i in range(num_steps):
            if i != 0 and not self.is_driver_worker:
                broadcast_data = broadcast_tensor_dict(src=0)
@@ -2175,17 +2188,22 @@ def execute_model(
            htorch.core.mark_step()
            if i < num_steps - 1:
                if i == 0:
                    import copy
                    ctx = model_input.async_callback.keywords[  # type: ignore
                        "ctx"]
                    seq_group_metadata_list = ctx.seq_group_metadata_list
                    seq_group_metadata_list = copy.deepcopy(
                        seq_group_metadata_list)
                    # Cache the original output token ids lengths
                    for i, seq_group_metadata in enumerate(
                            seq_group_metadata_list):
                        cache_orig_output_tokens_len.append({})
                        for j, data in seq_group_metadata.seq_data.items():
                            cache_orig_output_tokens_len[i][j] = \
                                len(data.output_token_ids)
                for seq_group_metadata in seq_group_metadata_list:
                    for data in seq_group_metadata.seq_data.values():
                        max_output_len = sampling_metadata.seq_groups[
                            0].sampling_params.max_tokens
                        if len(data.output_token_ids) < max_output_len - 1:
                            # add a placeholder for prepare_decode
                            # arbitrary value, this could be any token
                            dummy_token = (540, )
                            data.output_token_ids += dummy_token
@@ -2195,6 +2213,7 @@ def execute_model(
                        if num_steps == 1:
                            return [output]
                        else:
                            try_revert_dummy_output_tokens()
                            return []

                    result = self._prepare_decode(seq_group_metadata_list,
@@ -2213,6 +2232,8 @@ def execute_model(
                        "attn_metadata": vars(result.attn_metadata)
                    }
                    broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
                else:
                    try_revert_dummy_output_tokens()

        if self.is_driver_worker and self.profiler.enabled:
            # Stop recording 'execute_model' event

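For scale: `copy.deepcopy` walks every sequence group, its `seq_data` mapping, and every token container on each of the `num_steps` iterations, whereas the revert is one slice per sequence. A rough, self-contained micro-benchmark of the two strategies (the workload sizes are invented for illustration, not measurements from this commit):

```python
import copy
import timeit

# Hypothetical workload: 64 sequence groups, 1024 output tokens each.
groups = [{"output_token_ids": tuple(range(1024))} for _ in range(64)]
orig_lens = [len(g["output_token_ids"]) for g in groups]


def revert() -> None:
    # The commit's strategy: truncate back to the cached lengths.
    for g, n in zip(groups, orig_lens):
        g["output_token_ids"] = g["output_token_ids"][:n]


deepcopy_s = timeit.timeit(lambda: copy.deepcopy(groups), number=100)
revert_s = timeit.timeit(revert, number=100)
print(f"100 steps: deepcopy {deepcopy_s:.3f}s, revert {revert_s:.3f}s")
```

On the real path the gap is wider still, since `SequenceGroupMetadata` also carries sampling parameters and block tables that a deep copy must clone.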