Skip to content

Commit

Permalink
added specific long context changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rsshaik1 committed Oct 25, 2024
1 parent d265a7f commit 8917184
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 53 deletions.
92 changes: 52 additions & 40 deletions tests/lora/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch

import habana_frameworks.torch.core as htcore
import pytest
import torch
import torch.nn.functional as F
import habana_frameworks.torch.core as htcore

from vllm.config import LoRAConfig
from vllm_hpu_extension.ops import LoraMask
from vllm_hpu_extension.punica_hpu import GaudiPunicaWrapper

from vllm.config import LoRAConfig
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
Expand Down Expand Up @@ -42,8 +42,8 @@
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed
from vllm.utils import seed_everything
from vllm.platforms import current_platform
from vllm.utils import seed_everything

from .utils import DummyLoRAManager

Expand Down Expand Up @@ -242,18 +242,19 @@ def create_random_embedding_layer():
layer=lora_embedding,
layer_weights=embedding.weight.T,
)

htcore.mark_step()
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3,
input_size=(200, ),
input_range=(1, vocab_size),
)

indices_list = [id_to_index.index(value) for value in index_mapping]
indices = torch.tensor(indices_list)
mask = createLoraMask(indices, indices.shape[0], 1, max_loras, 8, torch.bfloat16)
mask = createLoraMask(indices, indices.shape[0], 1, max_loras, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand Down Expand Up @@ -294,8 +295,9 @@ def create_random_embedding_layer():
input_size=(200, ),
input_range=(1, vocab_size),
)
indices = torch.full((len(inputs)*len(inputs[0]), ), 0, device="hpu")
mask = createLoraMask(indices, indices.shape[0], 1, 8, 8, torch.bfloat16)
indices = torch.full((len(inputs) * len(inputs[0]), ), 0, device="hpu")
mask = createLoraMask(indices, indices.shape[0], 1, 8, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand Down Expand Up @@ -385,7 +387,8 @@ def create_random_embedding_layer():
)
indices_list = [id_to_index.index(value) for value in index_mapping]
indices = torch.tensor(indices_list)
mask = createLoraMask(indices, indices.shape[0], 1, max_loras, 8, torch.bfloat16)
mask = createLoraMask(indices, indices.shape[0], 1, max_loras, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand Down Expand Up @@ -443,8 +446,9 @@ def create_random_embedding_layer():
input_size=(200, ),
input_range=(1, vocab_size),
)
indices = torch.full((len(inputs)*len(inputs[0]), ), 0, device="hpu")
mask = createLoraMask(indices, indices.shape[0], 1, 8, 8, torch.bfloat16)
indices = torch.full((len(inputs) * len(inputs[0]), ), 0, device="hpu")
mask = createLoraMask(indices, indices.shape[0], 1, 8, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

original_inputs = deepcopy(inputs)
Expand Down Expand Up @@ -524,7 +528,8 @@ def _pretest():
)
indices_list = [id_to_index.index(value) for value in index_mapping]
indices = torch.tensor(indices_list)
mask = createLoraMask(indices, indices.shape[0], 1, max_loras, 8, torch.bfloat16)
mask = createLoraMask(indices, indices.shape[0], 1, max_loras, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand Down Expand Up @@ -576,8 +581,9 @@ def _pretest():
input_range=(0, 1),
input_type=torch.bfloat16,
)
indices = torch.full((len(inputs)*len(inputs[0]), ), 0, device="hpu")
mask = createLoraMask(indices, indices.shape[0], 1, 8, 8, torch.bfloat16)
indices = torch.full((len(inputs) * len(inputs[0]), ), 0, device="hpu")
mask = createLoraMask(indices, indices.shape[0], 1, 8, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand Down Expand Up @@ -657,7 +663,8 @@ def create_random_linear_replicated_layer():
)
indices_list = [id_to_index.index(value) for value in index_mapping]
indices = torch.tensor(indices_list)
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8, torch.bfloat16)
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand Down Expand Up @@ -701,7 +708,8 @@ def create_random_linear_replicated_layer():
input_type=torch.bfloat16,
)
indices = torch.full((len(inputs), ), 0, device="hpu")
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8, torch.bfloat16)
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand All @@ -722,8 +730,7 @@ def create_random_linear_replicated_layer():


@torch.inference_mode()
# @pytest.mark.skip(
# reason="Fails when fully_shard is True.")
# @pytest.mark.skip(reason="Fails when fully_shard is True.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
Expand All @@ -732,9 +739,9 @@ def create_random_linear_replicated_layer():
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device, stage) -> None:

if(fully_shard == True):
if ('fully_shard' == True):
pytest.skip("Skipping the test when fully_shard is True")

torch.set_default_device(torch.device("hpu"))
if current_platform.is_hpu():
punica_wrapper = GaudiPunicaWrapper(8192, 256, device="hpu")
Expand Down Expand Up @@ -789,7 +796,8 @@ def create_random_linear_parallel_layer():
)
indices_list = [id_to_index.index(value) for value in index_mapping]
indices = torch.tensor(indices_list)
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8, torch.bfloat16)
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand Down Expand Up @@ -833,7 +841,8 @@ def create_random_linear_parallel_layer():
input_type=torch.bfloat16,
)
indices = torch.full((len(inputs), ), 0, device="hpu")
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8, torch.bfloat16)
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand All @@ -854,8 +863,7 @@ def create_random_linear_parallel_layer():


@torch.inference_mode()
# @pytest.mark.skip(
# reason="Fails when fully_shard is True.")
# @pytest.mark.skip(reason="Fails when fully_shard is True.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
Expand All @@ -864,13 +872,13 @@ def create_random_linear_parallel_layer():
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device, stage) -> None:

if(fully_shard == True):
if ('fully_shard' == True):
pytest.skip("Skipping the test when fully_shard is True")

torch.set_default_device(torch.device("hpu"))
if current_platform.is_hpu():
punica_wrapper = GaudiPunicaWrapper(8192, 256, device="hpu")
else:
else:
punica_wrapper = PunicaWrapper(8192, 256, device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
Expand Down Expand Up @@ -943,7 +951,8 @@ class FakeConfig:
)
indices_list = [id_to_index.index(value) for value in index_mapping]
indices = torch.tensor(indices_list)
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8, torch.bfloat16)
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand Down Expand Up @@ -989,7 +998,8 @@ class FakeConfig:
input_type=torch.bfloat16,
)
indices = torch.full((len(inputs), ), 0, device="hpu")
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8, torch.bfloat16)
mask = createLoraMask(indices, len(inputs), 1, max_loras, 8,
torch.bfloat16)
LoraMask.setLoraMask(mask)

lora_mapping = LoRAMapping(index_mapping,
Expand Down Expand Up @@ -1050,28 +1060,30 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
num_heads = 7

# Verify lora is equivalent to linear scaling rotary embedding.
rope = get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype = torch.bfloat16
)
rope = get_rope(head_size,
rotary_dim,
max_position,
base,
is_neox_style,
dtype=torch.bfloat16)
lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
lora_rope.set_mapping(punica_wrapper)
lora_rope.create_lora_weights(max_loras, lora_config)
linear_rope = get_rope(head_size, rotary_dim, max_position, base,
linear_rope = get_rope(head_size,
rotary_dim,
max_position,
base,
is_neox_style, {
"type": "linear",
"factor": scaling_factors
}, dtype=torch.bfloat16)
},
dtype=torch.bfloat16)
#linear_rope = linear_rope.to(dtype=dtype)
id_to_index = get_random_id_to_index(num_loras, max_loras)
_, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=[0],
num_inputs=batch_size,
input_size=(seq_len, max_position),
input_size=(1, max_position),
input_range=(0, lora_config.lora_extra_vocab_size),
input_type=torch.bfloat16,
)
Expand Down
16 changes: 8 additions & 8 deletions vllm/lora/punica.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,8 @@ def convert_mapping(
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
long_lora_offsets: Optional[torch.Tensor] = None
if long_lora_context:
long_lora_offsets = torch.zeros(len(index_mapping_indices),
device=get_device(),
dtype=torch.long)
long_lora_offsets_list: List[int] = []

prompt_mapping: List[int] = [
lora_index_to_id.index(x) if x > 0 else -1
for x in mapping.prompt_mapping
Expand All @@ -119,11 +116,14 @@ def convert_mapping(
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
if long_lora_context:
assert long_lora_offsets is not None
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
index_mapping_indices[i], 0)
long_lora_offsets[i] = lora_offset
long_lora_offsets_list.append(lora_offset)

long_lora_offsets = torch.tensor(long_lora_offsets_list,
device=get_device(),
dtype=torch.long)
# breakpoint()
indices_list: List[Union[List[int], torch.Tensor]] = [
index_mapping_indices,
lora_indices,
Expand Down Expand Up @@ -607,4 +607,4 @@ def add_lora_logits(self,

bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
y = y.view_as(y_org)
y = y.view_as(y_org)
29 changes: 24 additions & 5 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sampling_params import SamplingParams
Expand Down Expand Up @@ -649,12 +650,30 @@ def load_model(self) -> None:
assert hasattr(
self.model, "embedding_padding_modules"
), "Model does not have embedding_padding_modules"

if supports_multimodal(self.model):
logger.warning(
"Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
# It's necessary to distinguish between the
# max_position_embeddings of VLMs and LLMs.
if hasattr(self.model.config, "max_position_embeddings"):
max_pos_embeddings = (
self.model.config.max_position_embeddings)
else:
max_pos_embeddings = (
self.model.config.text_config.max_position_embeddings)

self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
self.vocab_size, self.lora_config, self.device,
self.vocab_size,
self.lora_config,
self.device,
self.model.embedding_modules,
self.model.embedding_padding_modules)
self.model.embedding_padding_modules,
max_position_embeddings=max_pos_embeddings,
)
self.model = self.lora_manager.create_lora_manager(self.model)

if self.model_config.quantization == 'inc':
Expand Down Expand Up @@ -1314,7 +1333,8 @@ def profile_run(self) -> None:
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = self.max_num_batched_tokens // max_seq_len
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs)

self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
False, True)
Expand All @@ -1333,7 +1353,6 @@ def warmup_scenario(self,
f"bs{batch_size}_"
f"seq{seq_len}_"
f"graphs{'T' if use_graphs else 'F'}")
max_num_seqs = self.scheduler_config.max_num_seqs
# This represents the maximum number of different requests
# that will have unique loras, an therefore the max amount of memory
# consumption create dummy lora request copies from the lora request
Expand All @@ -1355,7 +1374,7 @@ def warmup_scenario(self,
dummy_lora_requests.append(dummy_lora_request)
dummy_lora_requests_per_seq = [
dummy_lora_requests[idx % len(dummy_lora_requests)]
for idx in range(max_num_seqs)
for idx in range(batch_size)
]
self.profiler.start('internal', scenario_name)
times = 3 if use_graphs or is_pt_profiler_run else 1
Expand Down

0 comments on commit 8917184

Please sign in to comment.