Fix recompilations due to different batch_sizes in MSS (#637)
Fix for batch size padding in multi-step scheduling by @SanjuCSudhakaran.

Co-authored-by: Sanju C Sudhakaran <[email protected]>
mfylcek and SanjuCSudhakaran authored Dec 16, 2024
1 parent 11c07e3 commit ba1d24b
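The HPU backend warms up and caches compiled graphs only for a bucketed set of batch sizes. With multi-step scheduling (MSS), the decode batch handed to the model runner could end up with a size that falls between buckets, forcing a fresh graph compilation at runtime. This commit pads such batches with dummy sequence groups up to the next bucketed size so an already-compiled graph is reused. A minimal sketch of the idea, with made-up bucket values and helper names standing in for the real bucketing context:

```python
# Illustrative sketch of bucketed batch-size padding; the bucket values and
# helper names are assumptions, not the actual vLLM HPU bucketing code.
from typing import List

DECODE_BUCKETS: List[int] = [1, 2, 4, 8, 16, 32, 64, 128]   # hypothetical


def get_padded_batch_size(real_batch_size: int) -> int:
    """Return the smallest warmed-up bucket that fits the real batch size."""
    for bucket in DECODE_BUCKETS:
        if bucket >= real_batch_size:
            return bucket
    return DECODE_BUCKETS[-1]


def pad_batch(batch: list, make_dummy) -> list:
    """Pad the batch with dummy entries up to the next bucket boundary."""
    padding = get_padded_batch_size(len(batch)) - len(batch)
    return batch + [make_dummy() for _ in range(padding)]


# A decode batch of 5 is padded to the 8-bucket, so the graph already compiled
# for batch size 8 is reused instead of triggering a new compilation for 5.
assert len(pad_batch(list(range(5)), dict)) == 8
```

In the diff below, the same idea is implemented as the add_dummy_seq method and applied to decode batches (is_prompt=False) on the multi-step path.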
Showing 1 changed file with 19 additions and 4 deletions.
vllm/worker/hpu_model_runner.py (23 changes: 19 additions & 4 deletions)
@@ -2019,6 +2019,19 @@ def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int],
 
         return lora_mask, lora_logits_mask
 
+    def add_dummy_seq(self, seq_group_metadata_list, is_prompt):
+        real_batch_size = len(seq_group_metadata_list)
+        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
+            real_batch_size, is_prompt)
+        batch_size_padding = batch_size_padded - real_batch_size
+        seq_group_metadata_list = seq_group_metadata_list.copy()
+        if batch_size_padding > 0:
+            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
+                0, 0, is_prompt)
+            seq_group_metadata_list.extend(dummy_seq_group_metadata
+                                           for _ in range(batch_size_padding))
+        return seq_group_metadata_list
+
     @torch.inference_mode()
     def execute_model(
         self,
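The new add_dummy_seq helper copies the incoming seq_group_metadata_list and appends dummy sequence groups until the batch length matches the bucketed size returned by self.bucketing_ctx.get_padded_batch_size. A rough behavioural sketch with stubbed collaborators (FakeBucketingCtx, its bucket values, and the dummy-metadata factory are assumptions for illustration; only add_dummy_seq itself mirrors the method above):

```python
# Rough behavioural sketch of add_dummy_seq with stubbed collaborators.
class FakeBucketingCtx:
    def get_padded_batch_size(self, real_batch_size, is_prompt):
        buckets = [1, 2, 4, 8] if is_prompt else [1, 2, 4, 8, 16, 32]
        return next((b for b in buckets if b >= real_batch_size), buckets[-1])


class FakeRunner:
    bucketing_ctx = FakeBucketingCtx()

    def create_dummy_seq_group_metadata(self, group_id, seq_len, is_prompt):
        return {"dummy": True, "is_prompt": is_prompt}

    # Same structure as the method added in this commit.
    def add_dummy_seq(self, seq_group_metadata_list, is_prompt):
        real_batch_size = len(seq_group_metadata_list)
        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
            real_batch_size, is_prompt)
        batch_size_padding = batch_size_padded - real_batch_size
        seq_group_metadata_list = seq_group_metadata_list.copy()
        if batch_size_padding > 0:
            dummy_seq_group_metadata = self.create_dummy_seq_group_metadata(
                0, 0, is_prompt)
            seq_group_metadata_list.extend(dummy_seq_group_metadata
                                           for _ in range(batch_size_padding))
        return seq_group_metadata_list


padded = FakeRunner().add_dummy_seq([{"dummy": False}] * 3, is_prompt=False)
assert len(padded) == 4        # three real sequences padded to the 4-bucket
assert padded[-1]["dummy"]     # trailing entries are dummy sequence groups
```

Note that .copy() adds the dummy entries to a new list rather than mutating the list object passed in.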
@@ -2105,8 +2118,8 @@ def execute_model(
         def try_revert_dummy_output_tokens():
             if len(cache_orig_output_tokens_len) > 0:
                 # Reuse the original output token ids length
-                for i, seq_group_metadata in enumerate(
-                        seq_group_metadata_list):
+                for i in range(len(cache_orig_output_tokens_len)):
+                    seq_group_metadata = seq_group_metadata_list[i]
                     for j, data in seq_group_metadata.seq_data.items():
                         orig_output_tokens_len = \
                             cache_orig_output_tokens_len[i][j]
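Since the decode path (third hunk below) caches each real sequence's output-token length before calling add_dummy_seq, cache_orig_output_tokens_len has one entry per real sequence and none for the padding. Bounding the revert loop by len(cache_orig_output_tokens_len), instead of enumerating the now-padded seq_group_metadata_list, keeps the cache_orig_output_tokens_len[i] lookup in range. A simplified sketch of that invariant (the dict-based stand-ins below are not the real SequenceGroupMetadata/SequenceData objects):

```python
# Simplified stand-ins for SequenceGroupMetadata/SequenceData; only the
# indexing invariant behind the loop change is illustrated here.
real_seqs = [{"seq_data": {0: ["tok"] * 3}}, {"seq_data": {1: ["tok"] * 5}}]
cache_orig_output_tokens_len = [{0: 3}, {1: 5}]   # one dict per *real* sequence

# add_dummy_seq pads the list after the cache was filled, so the padded list
# is longer than the cache:
seq_group_metadata_list = real_seqs + [{"seq_data": {}}]   # dummy entry

# Fixed loop: bounded by the cache, so every cache lookup is in range.
for i in range(len(cache_orig_output_tokens_len)):
    seq_group_metadata = seq_group_metadata_list[i]
    for j, data in seq_group_metadata["seq_data"].items():
        assert len(data) == cache_orig_output_tokens_len[i][j]

# Pre-fix pattern: enumerating the padded list also visits the dummy entry
# and indexes past the end of the cache.
try:
    for i, _ in enumerate(seq_group_metadata_list):
        _ = cache_orig_output_tokens_len[i]
except IndexError:
    pass  # i == 2 has no cached entry
```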
@@ -2184,16 +2197,18 @@ def try_revert_dummy_output_tokens():
                 else:
                     raise RuntimeError(
                         "seq_group_metadata_list is uninitialized")
-                for i, seq_group_metadata in enumerate(
+                for seq_idx, seq_group_metadata in enumerate(
                         seq_group_metadata_list):
                     # Skip empty steps
                     seq_group_metadata.state.current_step += (
                         num_steps - 2)
                     # Cache the original output token ids
                     cache_orig_output_tokens_len.append({})
                     for j, data in seq_group_metadata.seq_data.items():
-                        cache_orig_output_tokens_len[i][j] = \
+                        cache_orig_output_tokens_len[seq_idx][j] = \
                             len(data.output_token_ids)
+                seq_group_metadata_list = self.add_dummy_seq(
+                    seq_group_metadata_list, is_prompt=False)
                 for seq_group_metadata in seq_group_metadata_list:
                     for data in seq_group_metadata.seq_data.values():
                         max_output_len = sampling_metadata.seq_groups[
