Skip to content

Commit

Permalink
Add LoRA specific changes in MSS
Browse files (browse the repository at this point in the history)
... to support LoRA + MSS flow
  • Loading branch information
SanjuCSudhakaran committed Jan 10, 2025
1 parent 73aaf71 commit c6a9a2f
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,7 @@ def _prepare_decode(
device='cpu')
else:
real_batch_size = len(seq_group_metadata_list)
- input_tokens = output[:real_batch_size]
+ input_tokens = output[:real_batch_size].clone()

input_positions = torch.tensor(input_positions,
dtype=torch.long,
Expand Down Expand Up @@ -2250,18 +2250,31 @@ def try_revert_dummy_output_tokens():

result = self._prepare_decode(seq_group_metadata_list,
output=output)
+ if self.lora_config:
+ lora_mapping = LoRAMapping(
+ **dict(index_mapping=result.lora_index_mapping,
+ prompt_mapping=result.lora_prompt_mapping,
+ is_prefill=False))
+ self.set_active_loras(result.lora_requests,
+ lora_mapping)
+ lora_mask, lora_logits_mask = self.create_lora_mask(
+ result.input_tokens, result.lora_ids, False)

execute_model_kwargs.update({
"input_ids":
result.input_tokens,
"positions":
result.input_positions,
"attn_metadata":
- self.trim_attn_metadata(result.attn_metadata)
+ self.trim_attn_metadata(result.attn_metadata),
+ "lora_mask":
+ lora_mask,
})
model_kwargs_broadcast_data = {
"input_ids": result.input_tokens,
"positions": result.input_positions,
- "attn_metadata": vars(result.attn_metadata)
+ "attn_metadata": vars(result.attn_metadata),
+ "lora_mask": lora_mask,
}
broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
else:
Expand Down

0 comments on commit c6a9a2f

Please sign in to comment.