Add padding to encoder_seq_lens (HabanaAI#610)
Without this change, the following error can be observed:
```
[rank0]:   File "/software/users/kdamaszke/repos/vllm-fork/vllm/model_executor/models/mllama.py", line 959, in forward
[rank0]:     full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
[rank0]: RuntimeError: shape '[4, -1, 1]' is invalid for input of size 3
```
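The failure can be illustrated without PyTorch: a `view` (reshape) with one inferred `-1` dimension only succeeds when the element count is divisible by the product of the known dimensions. The helper below is a hypothetical sketch of that rule, not vLLM code:

```python
def can_view(numel, shape):
    """Return True if a tensor with `numel` elements can be viewed as
    `shape`, where at most one dimension may be -1 (inferred)."""
    known = 1
    inferred = 0
    for d in shape:
        if d == -1:
            inferred += 1
        else:
            known *= d
    if inferred == 0:
        return numel == known
    # The -1 dimension must resolve to a whole number of elements.
    return known > 0 and numel % known == 0

# The failing case from the traceback: 3 elements cannot fill [4, -1, 1].
print(can_view(3, [4, -1, 1]))  # → False
# After padding the encoder batch to the bucketed size of 4, it can.
print(can_view(4, [4, -1, 1]))  # → True
```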
It occurs when one of the requests is removed from the batch early. In
that case, the language model still operates on shapes padded to the
bucketed batch size, while the encoder input does not. This change pads
`encoder_seq_lens` to the expected batch size.
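The padding strategy described above can be sketched as a standalone function (a hypothetical simplification of the fix, which repeats the first sequence length until the list reaches the bucketed batch size):

```python
def pad_to_bucket(encoder_seq_lens, padded_batch_size):
    """Pad a list of encoder sequence lengths up to the bucketed batch
    size by repeating the first entry, mirroring the commit's approach."""
    padding = padded_batch_size - len(encoder_seq_lens)
    if padding > 0:
        encoder_seq_lens = encoder_seq_lens + [encoder_seq_lens[0]] * padding
    return encoder_seq_lens

# A request was removed, leaving 3 entries, but the bucket expects 4.
print(pad_to_bucket([10, 12, 9], 4))  # → [10, 12, 9, 10]
```

Repeating the first entry is safe here because the padded rows correspond to dummy slots the language model ignores; only the tensor shape needs to match.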
kdamaszk authored Dec 12, 2024
1 parent 7ef6b2c commit 449a89d
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions vllm/worker/hpu_enc_dec_model_runner.py
```diff
@@ -332,6 +332,15 @@ def _prepare_encoder_model_input_tensors(
         attn_metadata.cross_block_groups = block_groups
         attn_metadata.cross_block_usage = block_usage

+        # add padding to align with language model shapes
+        real_batch_size = len(seq_group_metadata_list)
+        batch_size_padded = self.bucketing_ctx.get_padded_batch_size(
+            real_batch_size, is_prompt)
+        batch_size_padding = batch_size_padded - real_batch_size
+        if batch_size_padding > 0:
+            encoder_seq_lens.extend(encoder_seq_lens[0]
+                                    for _ in range(batch_size_padding))
+
         encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
         attn_metadata.encoder_seq_lens = encoder_seq_lens
         attn_metadata.encoder_seq_lens_tensor = encoder_seq_lens_tensor
```
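The call to `self.bucketing_ctx.get_padded_batch_size` rounds the real batch size up to a pre-configured bucket so that compiled HPU graph shapes can be reused. A minimal sketch of that rounding, assuming a fixed power-of-two bucket list (the actual bucketing in the HPU fork is configurable):

```python
def get_padded_batch_size(real_batch_size, buckets=(1, 2, 4, 8, 16, 32)):
    """Hypothetical stand-in for bucketing_ctx.get_padded_batch_size:
    round the real batch size up to the nearest configured bucket."""
    for bucket in buckets:
        if bucket >= real_batch_size:
            return bucket
    # Fall back to the largest bucket if the batch exceeds all of them.
    return buckets[-1]

# Three live requests land in the size-4 bucket, so one row of padding
# is needed to keep encoder and language-model shapes aligned.
print(get_padded_batch_size(3))  # → 4
```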
