Add PunicaWrapperHPU to handle LoRA computations
SanjuCSudhakaran committed Dec 11, 2024
1 parent 5a166da commit b8fff21
Showing 4 changed files with 98 additions and 6 deletions.
3 changes: 3 additions & 0 deletions vllm/lora/layers.py
@@ -1070,8 +1070,11 @@ def _get_logits(
        ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                      posinf=float("inf"),
                                                      neginf=float("-inf")))

        # HPU needs special handling to prune out dummy samples
        if current_platform.is_hpu():
            lora_logits = lora_logits[:logits.shape[0], :]

        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits
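A note on the new HPU branch above: HPU batches are padded to fixed bucket sizes, so lora_logits can carry extra dummy rows that the base logits do not have, and the slice prunes them before the LoRA vocab columns are written back. A minimal sketch of the effect (shapes are illustrative, not taken from the commit):

import torch

# Assume the real batch has 3 samples but HPU bucketing padded it to 4.
logits = torch.zeros(3, 32000)       # base-model logits for the real samples
lora_logits = torch.randn(4, 256)    # LoRA logits computed on the padded batch

# Prune the dummy rows so the column assignment into
# logits[:, org_vocab_size:...] lines up row-for-row.
lora_logits = lora_logits[:logits.shape[0], :]
assert lora_logits.shape[0] == logits.shape[0]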
87 changes: 87 additions & 0 deletions vllm/lora/punica_wrapper/punica_hpu.py
@@ -0,0 +1,87 @@
from typing import Optional, Tuple, Union, final

import torch
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
                                    dispatch_bgmv_linear)

from .punica_base import PunicaWrapperBase


@final
class PunicaWrapperHPU(PunicaWrapperBase):

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        # Increasing max_num_batched_tokens by 3x to handle increase in
        # tensor size due to padding.
        PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                   max_batches, device)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_input: bool = True,
                           **kwargs) -> None:
        dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: Tuple[int, ...],
                        *,
                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        y_org = y
        x = x.view(-1, x.shape[-1])
        y = y.view(-1, y.shape[-1])
        offset_left = 0

        for slice_idx in range(len(output_slices)):
            dispatch_bgmv_linear(
                y[:, offset_left:offset_left + output_slices[slice_idx]], x,
                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
        y = y.view_as(y_org)

    def add_shrink(
        self,
        y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
        x: torch.Tensor,
        lora_a_stacked: Tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ) -> None:
        raise NotImplementedError

    def add_expand(
        self,
        y: torch.Tensor,
        x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: Tuple[torch.Tensor, ...],
        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
        output_slices: Tuple[int, ...],
        offset_start: int = 0,
        add_input=True,
        **kwargs,
    ) -> None:
        raise NotImplementedError
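Two notes on the wrapper above. The constructor passes 3 * max_num_batched_tokens to the base class because, per the inline comment, padding grows the tensors the wrapper has to index. And add_lora_linear walks the packed output column-by-column: output_slices gives the width of each sub-projection (for example Q, K and V merged into one tensor) and offset_left tracks where each slice's LoRA delta lands. A plain-torch sketch of just that slice bookkeeping (widths and the fake delta are made up for illustration; on HPU the real work is done by dispatch_bgmv_linear):

import torch

num_tokens = 8
output_slices = (16, 16, 32)                      # hypothetical Q/K/V slice widths
y = torch.zeros(num_tokens, sum(output_slices))   # packed output, as in add_lora_linear

offset_left = 0
for slice_idx in range(len(output_slices)):
    # add_lora_linear hands this column view to dispatch_bgmv_linear, which
    # accumulates that slice's LoRA contribution in place; here a constant
    # stands in for the delta so the offset arithmetic can be checked.
    width = output_slices[slice_idx]
    y[:, offset_left:offset_left + width] += float(slice_idx + 1)
    offset_left += width

assert offset_left == y.shape[1]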
9 changes: 4 additions & 5 deletions vllm/lora/punica_wrapper/punica_selector.py
@@ -10,11 +10,10 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
         from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
         print_info_once("Using PunicaWrapperGPU.")
         return PunicaWrapperGPU(*args, **kwargs)
-    if current_platform.is_hpu():
+    elif current_platform.is_hpu():
         # Lazy import to avoid ImportError
-        from vllm_hpu_extension.punica_hpu import GaudiPunicaWrapper
-        print_info_once("Using GaudiPunicaWrapper.")
-        return GaudiPunicaWrapper(*args, **kwargs)
-
+        from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
+        print_info_once("Using PunicaWrapperHPU.")
+        return PunicaWrapperHPU(*args, **kwargs)
     else:
         raise NotImplementedError
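With this change the HPU branch resolves to the new in-tree PunicaWrapperHPU instead of the external GaudiPunicaWrapper. A hedged sketch of how the selector might be called (the actual call site is not part of this diff; argument names follow the PunicaWrapperHPU constructor defined above, and the values are placeholders):

# Hypothetical invocation on an HPU worker.
punica_wrapper = get_punica_wrapper(
    max_num_batched_tokens=2048,
    max_batches=256,
    device="hpu",
)
# punica_wrapper is a PunicaWrapperHPU here; its LoRA buffers are sized for
# 3 * 2048 tokens because of the padding headroom added in __init__.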
5 changes: 4 additions & 1 deletion vllm/worker/hpu_model_runner.py
@@ -687,7 +687,10 @@ def load_model(self) -> None:
                assert hasattr(
                    self.model, "embedding_padding_modules"
                ), "Model does not have embedding_padding_modules"

                assert not self.lora_config.bias_enabled, \
                    "Bias support in LoRA is not enabled in HPU yet."
                assert not self.lora_config.fully_sharded_loras, \
                    "Fully sharded LoRAs is not enabled in HPU yet."
                if supports_multimodal(self.model):
                    logger.warning(
                        "Regarding multimodal models, vLLM currently "
