Fix Lora Rebase (#290)
Fixes LoRA-related issues in the vLLM rebase.
hlahkar authored Sep 20, 2024
1 parent 8e41fb5 commit b2653ab
Showing 8 changed files with 181 additions and 164 deletions.
108 changes: 75 additions & 33 deletions tests/lora/test_lora_hpu.py
@@ -1,7 +1,8 @@
import pytest
import torch
from vllm.hpu.ops import LoraMask

from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice
from vllm.hpu.punica_hpu import GaudiPunicaWrapper

from .utils import DummyLoRAManager

@@ -19,7 +20,19 @@
torch.float16: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
MAX_LORAS = 8


def createLoraMask(indices, batch_size, seq_len, max_loras, max_lora_rank,
lora_dtype):
indices = indices.view(-1, 1)
mask = torch.arange(max_loras * max_lora_rank, device=indices.device)
mask = mask.view(1, -1)
mask = ((mask >= ((indices) * max_lora_rank)) *
(mask < ((indices + 1) * max_lora_rank))).to(dtype=lora_dtype)
mask = mask.view(batch_size, 1,
-1).expand(batch_size, seq_len,
-1).reshape(batch_size * seq_len, -1)
return mask
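
The new createLoraMask helper builds a (batch_size * seq_len, max_loras * max_lora_rank) mask: each row is 1 over the rank-sized block owned by that token's LoRA index and 0 everywhere else, and an index of -1 produces an all-zero row (no LoRA applied). A minimal sketch of the same logic, with toy sizes chosen purely for illustration:

import torch

def create_lora_mask(indices, batch_size, seq_len, max_loras, max_lora_rank,
                     lora_dtype):
    # Same computation as createLoraMask above, shown standalone for illustration.
    indices = indices.view(-1, 1)
    positions = torch.arange(max_loras * max_lora_rank,
                             device=indices.device).view(1, -1)
    mask = ((positions >= indices * max_lora_rank) &
            (positions < (indices + 1) * max_lora_rank)).to(dtype=lora_dtype)
    return (mask.view(batch_size, 1, -1)
                .expand(batch_size, seq_len, -1)
                .reshape(batch_size * seq_len, -1))

# Two tokens using adapters 1 and 0, with max_loras=2 and rank=2:
print(create_lora_mask(torch.tensor([1, 0]), 2, 1, 2, 2, torch.float32))
# tensor([[0., 0., 1., 1.],
#         [1., 1., 0., 0.]])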


@pytest.mark.parametrize("m", TENSOR_SIZES)
@@ -39,32 +52,41 @@ def test_apply_lora(m, n, k, rank, dtype) -> None:
input = torch.rand(k, n, device="hpu", dtype=dtype)
expected = input @ lora.lora_a @ lora.lora_b * lora.scaling

lora_a_stack = torch.zeros(MAX_LORAS + 1,
lora_a_stack = torch.zeros(8,
1,
lora.lora_a.shape[1],
lora.lora_a.shape[0],
device="hpu",
dtype=dtype)
lora_b_stack = torch.zeros(MAX_LORAS + 1,
lora_b_stack = torch.zeros(8,
1,
lora.lora_b.shape[1],
lora.lora_b.shape[0],
device="hpu",
dtype=dtype)
for i in range(MAX_LORAS):
for i in range(lora_a_stack.shape[0]):
lora_a_stack[i][0] = lora.lora_a.T
lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T

output = torch.zeros(k, m, device="hpu", dtype=dtype)
_apply_lora(input, lora_a_stack, lora_b_stack,
torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"),
output)
indices = torch.randint(0,
lora_a_stack.shape[0], (len(input), ),
device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)
punica_wrapper = GaudiPunicaWrapper(4096, max_batches=256, device="hpu")

punica_wrapper.add_lora(output, input, lora_a_stack, lora_b_stack, 1.0)

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora(input, lora_a_stack, lora_b_stack,
torch.full((len(input), ), -1, device="hpu"), output)
indices = torch.full((len(input), ), -1, device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

punica_wrapper.add_lora(output, input, lora_a_stack, lora_b_stack, 1.0)
assert torch.allclose(torch.zeros_like(output), output)

manager.reset_lora()
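
Note why the second assertion holds: filling indices with -1 turns the upper-bound check in createLoraMask into positions < 0, which no position satisfies, so the mask is all zeros and add_lora leaves output untouched. A quick check with the illustrative sketch above:

no_lora = create_lora_mask(torch.tensor([-1, -1]), 2, 1, 2, 2, torch.float32)
assert torch.equal(no_lora, torch.zeros(2, 4))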
@@ -99,40 +121,48 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
dim=1)

lora_a_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_1.lora_a.shape[1],
lora_1.lora_a.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_1.lora_b.shape[1],
lora_1.lora_b.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
for i in range(MAX_LORAS):
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_1.lora_a.T
lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
lora_a_stacks[1][i][0] = lora_2.lora_a.T
lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T

output = torch.zeros(k, m, device="hpu", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output,
(m // 2, m // 2))
indices = torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

punica_wrapper = GaudiPunicaWrapper(4096, max_batches=256, device="hpu")
punica_wrapper.add_lora_packed_nslice(output, input, lora_a_stacks,
lora_b_stacks, 1.0, (m // 2, m // 2))

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="hpu"),
output, (m // 2, m // 2))
indices = torch.full((len(input), ), -1, device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

punica_wrapper.add_lora_packed_nslice(output, input, lora_a_stacks,
lora_b_stacks, 1.0, (m // 2, m // 2))
assert torch.allclose(torch.zeros_like(output), output)

manager.reset_lora()
@@ -166,36 +196,36 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
dim=1)

lora_a_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_q.lora_a.shape[1],
lora_q.lora_a.shape[0],
device="hpu",
dtype=dtype)
] + [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_k.lora_a.shape[1],
lora_k.lora_a.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_q.lora_b.shape[1],
lora_q.lora_b.shape[0],
device="hpu",
dtype=dtype)
] + [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_k.lora_b.shape[1],
lora_k.lora_b.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
for i in range(MAX_LORAS):
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_q.lora_a.T
lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
lora_a_stacks[1][i][0] = lora_k.lora_a.T
@@ -204,18 +234,30 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T

output = torch.zeros(k, sum(qkv), device="hpu", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output,
(qkv[0], qkv[1], qkv[2]))
indices = torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

punica_wrapper = GaudiPunicaWrapper(4096, max_batches=256, device="hpu")
punica_wrapper.add_lora_packed_nslice(output, input,
lora_a_stacks,
lora_b_stacks,
1.0, (qkv[0], qkv[1], qkv[2]))

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="hpu"),
output, (qkv[0], qkv[1], qkv[2]))
indices = torch.full((len(input), ), -1, device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

punica_wrapper.add_lora_packed_nslice(output, input,
lora_a_stacks,
lora_b_stacks,
1.0, (qkv[0], qkv[1], qkv[2]))
assert torch.allclose(torch.zeros_like(output), output)

manager.reset_lora()
manager.reset_lora()
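
The packed variants cover fused layers whose output concatenates several projections: two equal halves in the 2-slice test and a Q/K/V split in the 3-slice test. Each slice carries its own stacked A/B weights, and add_lora_packed_nslice applies slice i to its own column range of the output; output_slices just records the slice widths. An illustrative check of how those widths map to column ranges (toy sizes, not taken from the tests):

output_slices = (16, 16, 32)          # e.g. a toy (q, k, v) split
offset = 0
for i, width in enumerate(output_slices):
    print(f"slice {i}: columns [{offset}, {offset + width})")
    offset += width
# slice 0: columns [0, 16)
# slice 1: columns [16, 32)
# slice 2: columns [32, 64)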
2 changes: 0 additions & 2 deletions vllm/hpu/ops.py
@@ -193,7 +193,6 @@ def dispatch_bgmv_linear(
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
indices: torch.LongTensor,
layer_idx: int,
scale: float,
):
@@ -228,7 +227,6 @@ def dispatch_bgmv_embedding(
y: torch.Tensor,
x: torch.Tensor,
wb_t_all: torch.Tensor,
indices: torch.LongTensor,
layer_idx: int,
scale: float,
):
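
These two signature changes go together with the new mask: dispatch_bgmv_linear and dispatch_bgmv_embedding no longer take a per-call indices tensor, because adapter selection is carried by the mask registered via LoraMask.setLoraMask. Conceptually, stacking all adapters into one wide matmul and multiplying the intermediate activations by the mask zeroes every rank block except the one belonging to each token's adapter. A sketch of that idea only, not the HPU kernel implementation:

import torch

def masked_lora_matmul(y, x, wa_t_all, wb_t_all, mask, scale):
    # wa_t_all: (max_loras, 1, rank, in_dim)   -- stacked lora_a.T per adapter
    # wb_t_all: (max_loras, 1, out_dim, rank)  -- stacked (lora_b * scaling).T
    # mask:     (num_tokens, max_loras * rank) -- as built by createLoraMask
    max_loras, _, rank, in_dim = wa_t_all.shape
    out_dim = wb_t_all.shape[2]
    wa = wa_t_all[:, 0].reshape(max_loras * rank, in_dim).t()        # (in, L*r)
    wb = wb_t_all[:, 0].transpose(1, 2).reshape(max_loras * rank, out_dim)
    # The mask keeps only each token's own rank block, so a single pair of
    # wide matmuls applies the correct adapter to every token in the batch.
    y += (((x @ wa) * mask) @ wb) * scale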
77 changes: 77 additions & 0 deletions vllm/hpu/punica_hpu.py
@@ -0,0 +1,77 @@
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
###############################################################################

from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union

import torch
from vllm.lora.punica import PunicaWrapper
from vllm.hpu.ops import dispatch_bgmv_linear, dispatch_bgmv_embedding

class GaudiPunicaWrapper(PunicaWrapper):
def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: str):
super().__init__(max_num_batched_tokens, max_batches, device)

def add_lora(self,
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
scale: float,
y_offset: Optional[int] = None,
y_slice_size: Optional[int] = None,
*,
buffer: Optional[torch.Tensor] = None) -> None:
y_org = y
x = x.view(-1, x.shape[-1])
y = y.view(-1, y.shape[-1])
dispatch_bgmv_linear(y, x, wa_t_all, wb_t_all, 0, 1.0)
y = y.view_as(y_org)

def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
lora_a_stacked: Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor],
lora_b_stacked: Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor],
scale: float,
output_slices: Tuple[int, ...]) -> None:
y_org = y
x = x.view(-1, x.shape[-1])
y = y.view(-1, y.shape[-1])
offset_left = 0

for slice_idx in range(len(output_slices)):
dispatch_bgmv_linear(
y[:, offset_left:offset_left + output_slices[slice_idx]],
x, lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, 1.0)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)

def add_lora_logits(self,
y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,
wb_t_all: torch.Tensor,
scale,
*,
buffer: Optional[torch.Tensor] = None) -> None:
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
dispatch_bgmv_linear(y, x, wa_t_all, wb_t_all, 0, 1.0)
y = y.view_as(y_org)

def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_input: bool = True,
):
dispatch_bgmv_embedding(y, x, w_t_all, 0, 1.0)
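
For reference, the call pattern the updated tests exercise is: build the mask for the current batch, register it with LoraMask.setLoraMask, then hand the stacked weights to the wrapper; the Gaudi dispatch functions read the mask internally instead of taking indices. A condensed sketch (tensor sizes are illustrative, and create_lora_mask stands in for the createLoraMask helper defined in the test file above):

import torch
from vllm.hpu.ops import LoraMask
from vllm.hpu.punica_hpu import GaudiPunicaWrapper

max_loras, rank, hidden, out_dim, tokens = 8, 8, 1024, 1024, 32
dtype, device = torch.bfloat16, "hpu"

x = torch.rand(tokens, hidden, device=device, dtype=dtype)
y = torch.zeros(tokens, out_dim, device=device, dtype=dtype)
wa = torch.rand(max_loras, 1, rank, hidden, device=device, dtype=dtype)
wb = torch.rand(max_loras, 1, out_dim, rank, device=device, dtype=dtype)

indices = torch.randint(0, max_loras, (tokens, ), device=device)
mask = create_lora_mask(indices, tokens, 1, max_loras, rank, dtype)
LoraMask.setLoraMask(mask)

punica_wrapper = GaudiPunicaWrapper(4096, max_batches=256, device=device)
punica_wrapper.add_lora(y, x, wa, wb, 1.0)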