Commit a0f9f3c

Merge habana_main into private/jmaksymczuk/fake_hpu_cpu.
jmaksymczuk committed Sep 12, 2024
2 parents (d4efdba + f858d43), commit a0f9f3c
Showing 2 changed files with 68 additions and 31 deletions.
tests/lora/test_lora_hpu.py: 93 changes (64 additions, 29 deletions)
@@ -1,6 +1,7 @@
 import pytest
 import torch
 
+from vllm.hpu.ops import LoraMask
 from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice
 
 from .utils import DummyLoRAManager
@@ -19,7 +20,19 @@
     torch.float16: (5e-3, 5e-3),
     torch.bfloat16: (3e-2, 2e-2),
 }
-MAX_LORAS = 8
+
+
+def createLoraMask(indices, batch_size, seq_len, max_loras, max_lora_rank,
+                   lora_dtype):
+    indices = indices.view(-1, 1)
+    mask = torch.arange(max_loras * max_lora_rank, device=indices.device)
+    mask = mask.view(1, -1)
+    mask = ((mask >= ((indices) * max_lora_rank)) *
+            (mask < ((indices + 1) * max_lora_rank))).to(dtype=lora_dtype)
+    mask = mask.view(batch_size, 1,
+                     -1).expand(batch_size, seq_len,
+                                -1).reshape(batch_size * seq_len, -1)
+    return mask
 
 
 @pytest.mark.parametrize("m", TENSOR_SIZES)
@@ -39,32 +52,40 @@ def test_apply_lora(m, n, k, rank, dtype) -> None:
     input = torch.rand(k, n, device="hpu", dtype=dtype)
     expected = input @ lora.lora_a @ lora.lora_b * lora.scaling
 
-    lora_a_stack = torch.zeros(MAX_LORAS + 1,
+    lora_a_stack = torch.zeros(8,
                                1,
                                lora.lora_a.shape[1],
                                lora.lora_a.shape[0],
                                device="hpu",
                                dtype=dtype)
-    lora_b_stack = torch.zeros(MAX_LORAS + 1,
+    lora_b_stack = torch.zeros(8,
                                1,
                                lora.lora_b.shape[1],
                                lora.lora_b.shape[0],
                                device="hpu",
                                dtype=dtype)
-    for i in range(MAX_LORAS):
+    for i in range(lora_a_stack.shape[0]):
         lora_a_stack[i][0] = lora.lora_a.T
         lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T
 
     output = torch.zeros(k, m, device="hpu", dtype=dtype)
-    _apply_lora(input, lora_a_stack, lora_b_stack,
-                torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"),
-                output)
+    indices = torch.randint(0,
+                            lora_a_stack.shape[0], (len(input), ),
+                            device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora(input, lora_a_stack, lora_b_stack, indices, output)
 
     rtol, atol = TOLERANCES[dtype]
     assert torch.allclose(expected, output, rtol=rtol, atol=atol)
 
     output[:] = 0
-    _apply_lora(input, lora_a_stack, lora_b_stack,
-                torch.full((len(input), ), -1, device="hpu"), output)
+    indices = torch.full((len(input), ), -1, device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora(input, lora_a_stack, lora_b_stack, indices, output)
     assert torch.allclose(torch.zeros_like(output), output)
 
     manager.reset_lora()
@@ -99,39 +120,46 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
                          dim=1)
 
     lora_a_stacks = [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_1.lora_a.shape[1],
                     lora_1.lora_a.shape[0],
                     device="hpu",
                     dtype=dtype) for i in range(2)
     ]
     lora_b_stacks = [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_1.lora_b.shape[1],
                     lora_1.lora_b.shape[0],
                     device="hpu",
                     dtype=dtype) for i in range(2)
     ]
-    for i in range(MAX_LORAS):
+    for i in range(lora_a_stacks[0].shape[0]):
         lora_a_stacks[0][i][0] = lora_1.lora_a.T
         lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
         lora_a_stacks[1][i][0] = lora_2.lora_a.T
         lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T
 
     output = torch.zeros(k, m, device="hpu", dtype=dtype)
-    _apply_lora_packed_nslice(
-        input, lora_a_stacks, lora_b_stacks,
-        torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output,
-        (m // 2, m // 2))
+    indices = torch.randint(0,
+                            lora_a_stacks[0].shape[0], (len(input), ),
+                            device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
+                              output, (m // 2, m // 2))
 
     rtol, atol = TOLERANCES[dtype]
    assert torch.allclose(expected, output, rtol=rtol, atol=atol)
 
     output[:] = 0
-    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
-                              torch.full((len(input), ), -1, device="hpu"),
+    indices = torch.full((len(input), ), -1, device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
                               output, (m // 2, m // 2))
     assert torch.allclose(torch.zeros_like(output), output)
 
@@ -166,36 +194,36 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
                          dim=1)
 
     lora_a_stacks = [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_q.lora_a.shape[1],
                     lora_q.lora_a.shape[0],
                     device="hpu",
                     dtype=dtype)
     ] + [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_k.lora_a.shape[1],
                     lora_k.lora_a.shape[0],
                     device="hpu",
                     dtype=dtype) for i in range(2)
     ]
     lora_b_stacks = [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_q.lora_b.shape[1],
                     lora_q.lora_b.shape[0],
                     device="hpu",
                     dtype=dtype)
     ] + [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_k.lora_b.shape[1],
                     lora_k.lora_b.shape[0],
                     device="hpu",
                     dtype=dtype) for i in range(2)
     ]
-    for i in range(MAX_LORAS):
+    for i in range(lora_a_stacks[0].shape[0]):
         lora_a_stacks[0][i][0] = lora_q.lora_a.T
         lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
         lora_a_stacks[1][i][0] = lora_k.lora_a.T
@@ -204,17 +232,24 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
         lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T
 
     output = torch.zeros(k, sum(qkv), device="hpu", dtype=dtype)
-    _apply_lora_packed_nslice(
-        input, lora_a_stacks, lora_b_stacks,
-        torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output,
-        (qkv[0], qkv[1], qkv[2]))
+    indices = torch.randint(0,
+                            lora_a_stacks[0].shape[0], (len(input), ),
+                            device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
+                              output, (qkv[0], qkv[1], qkv[2]))
 
     rtol, atol = TOLERANCES[dtype]
     assert torch.allclose(expected, output, rtol=rtol, atol=atol)
 
     output[:] = 0
-    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
-                              torch.full((len(input), ), -1, device="hpu"),
+    indices = torch.full((len(input), ), -1, device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
                               output, (qkv[0], qkv[1], qkv[2]))
     assert torch.allclose(torch.zeros_like(output), output)

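The reworked tests no longer pass only an index tensor to _apply_lora / _apply_lora_packed_nslice; they first build a per-token mask with createLoraMask and register it through LoraMask.setLoraMask. The snippet below is a standalone, CPU-only sketch (not part of this commit; the helper name, toy sizes, and values are invented for illustration) that mirrors the logic of createLoraMask above to show what the mask looks like:

```python
import torch


def create_lora_mask(indices, batch_size, seq_len, max_loras, max_lora_rank,
                     lora_dtype):
    # Same idea as createLoraMask above: one row per token, with ones in the
    # rank-wide column block belonging to that token's LoRA index.
    indices = indices.view(-1, 1)
    mask = torch.arange(max_loras * max_lora_rank, device=indices.device)
    mask = mask.view(1, -1)
    mask = ((mask >= (indices * max_lora_rank)) *
            (mask < ((indices + 1) * max_lora_rank))).to(dtype=lora_dtype)
    return (mask.view(batch_size, 1, -1)
                .expand(batch_size, seq_len, -1)
                .reshape(batch_size * seq_len, -1))


indices = torch.tensor([0, 2, 1, -1])  # LoRA id per sequence; -1 means "no LoRA"
mask = create_lora_mask(indices, batch_size=4, seq_len=1,
                        max_loras=4, max_lora_rank=2,
                        lora_dtype=torch.float32)
print(mask.shape)  # torch.Size([4, 8]) -> max_loras * max_lora_rank columns
print(mask)
# tensor([[1., 1., 0., 0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 1., 1., 0., 0.],
#         [0., 0., 1., 1., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0., 0., 0., 0.]])
```

An index of -1 (the "no active LoRA" case the tests exercise via torch.full(..., -1)) selects no column block, so its mask row is all zeros and the LoRA contribution drops out, which is exactly what the final torch.allclose(torch.zeros_like(output), output) assertions check.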
vllm/worker/habana_model_runner.py: 6 changes (4 additions, 2 deletions)
@@ -238,11 +238,12 @@ def pad_list(list, k, v):
 
 class HpuModelAdapter():
 
-    def __init__(self, model, block_size, enforce_eager):
+    def __init__(self, model, block_size, dtype, enforce_eager):
         self.model = model
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '0').lower() in ['1', 'true']
         self.block_size = block_size
+        self.dtype = dtype
         if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
         ) and not enforce_eager:
             self.model = torch.compile(self.model,
@@ -305,7 +306,7 @@ def forward(self, *args, **kwargs):
         input_ids = kwargs['input_ids']
         kwargs['attn_metadata'] = self._update_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
-            input_ids.device, torch.bfloat16)
+            input_ids.device, self.dtype)
         LoraMask.setLoraMask(kwargs.pop('lora_mask'))
         hidden_states = self.model(*args, **kwargs)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -603,6 +604,7 @@ def load_model(self) -> None:
                 self.model = _maybe_wrap_in_hpu_graph(
                     self.model,
                     self.block_size,
+                    dtype=self.model_config.dtype,
                     enforce_eager=self.enforce_eager)
             msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
             logger.info(msg)

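In habana_model_runner.py, the model's dtype is now threaded through _maybe_wrap_in_hpu_graph into HpuModelAdapter, and forward() passes self.dtype to _update_metadata instead of a hard-coded torch.bfloat16. The sketch below is a generic illustration of why the dtype matters when building mask/bias-style metadata tensors; it is not vLLM code, and make_causal_bias with its toy sizes is invented for this example:

```python
import torch


def make_causal_bias(batch_size, seq_len, device, dtype):
    # Positions above the diagonal get -inf so softmax ignores future tokens.
    # Building the bias in the model's dtype avoids an implicit cast when it
    # is later combined with float16/float32 attention scores.
    upper = torch.triu(
        torch.ones(seq_len, seq_len, device=device, dtype=torch.bool),
        diagonal=1)
    bias = torch.zeros(seq_len, seq_len, device=device, dtype=dtype)
    bias.masked_fill_(upper, float("-inf"))
    return bias.expand(batch_size, 1, seq_len, seq_len)


bias = make_causal_bias(batch_size=2, seq_len=4, device="cpu",
                        dtype=torch.float16)
print(bias.dtype)  # torch.float16, matching a float16 model's activations
```

With self.dtype stored in __init__ and forwarded in forward(), the adapter follows whatever dtype the model was loaded with (self.model_config.dtype) rather than assuming bfloat16.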