Skip to content

Commit

Permalink
Handle LoRA-specific changes in MSS (#675)
Browse files Browse the repository at this point in the history
This PR adds the changes required to enable MSS (multi-step scheduling) with the LoRA flow. Verified
that there are no regressions using the vllm-fork CI job:
https://tf-jenkins-ctrl01.habana-labs.com/job/vLLM/view/CI-jobs/job/vLLM-CI-Pipeline/429/
  • Loading branch information
SanjuCSudhakaran authored Jan 11, 2025
1 parent 73aaf71 commit c5975f8
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,7 @@ def _prepare_decode(
device='cpu')
else:
real_batch_size = len(seq_group_metadata_list)
input_tokens = output[:real_batch_size]
input_tokens = output[:real_batch_size].clone()

input_positions = torch.tensor(input_positions,
dtype=torch.long,
Expand Down Expand Up @@ -2250,18 +2250,31 @@ def try_revert_dummy_output_tokens():

result = self._prepare_decode(seq_group_metadata_list,
output=output)
if self.lora_config:
lora_mapping = LoRAMapping(
**dict(index_mapping=result.lora_index_mapping,
prompt_mapping=result.lora_prompt_mapping,
is_prefill=False))
self.set_active_loras(result.lora_requests,
lora_mapping)
lora_mask, lora_logits_mask = self.create_lora_mask(
result.input_tokens, result.lora_ids, False)

execute_model_kwargs.update({
"input_ids":
result.input_tokens,
"positions":
result.input_positions,
"attn_metadata":
self.trim_attn_metadata(result.attn_metadata)
self.trim_attn_metadata(result.attn_metadata),
"lora_mask":
lora_mask,
})
model_kwargs_broadcast_data = {
"input_ids": result.input_tokens,
"positions": result.input_positions,
"attn_metadata": vars(result.attn_metadata)
"attn_metadata": vars(result.attn_metadata),
"lora_mask": lora_mask,
}
broadcast_tensor_dict(model_kwargs_broadcast_data, src=0)
else:
Expand Down

0 comments on commit c5975f8

Please sign in to comment.