Handle offsets shape in long contexts (#477)
This PR changes the view of the `offsets` tensor to `(batch_size, -1)` so that it
broadcasts correctly against the positions tensor in the long-context path.
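
As a rough illustration of the broadcasting this enables (the shapes and offset values below are invented for the example; only the view/add/flatten pattern mirrors the patched `forward_hpu` path):

```python
import torch

# Hypothetical shapes: 2 sequences of 4 positions, one long-context offset per sequence.
positions = torch.arange(8).view(2, 4)   # (batch_size, seq_len)
offsets = torch.tensor([0, 4096])        # (batch_size,)

# Viewing offsets as (batch_size, -1) yields (2, 1), which broadcasts across seq_len.
offsets = offsets.view(positions.shape[0], -1)
positions = (positions + offsets).flatten()   # (batch_size * seq_len,)
print(positions)  # tensor([   0,    1,    2,    3, 4100, 4101, 4102, 4103])
```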
vivekgoe authored Nov 11, 2024
2 parents 41dddab + 6eed0ef commit 1565944
Showing 4 changed files with 20 additions and 20 deletions.
9 changes: 4 additions & 5 deletions tests/lora/test_layers_hpu.py
@@ -43,7 +43,6 @@
     ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
 from vllm.model_executor.utils import set_random_seed
 from vllm.platforms import current_platform
-from vllm.utils import seed_everything

 from .utils import DummyLoRAManager

@@ -1043,8 +1042,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                        seq_len) -> None:
     dtype = torch.bfloat16
     seed = 0
-    seed_everything(seed)
-    torch.set_default_device(torch.device("hpu"))
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
     if current_platform.is_hpu():
         punica_wrapper = GaudiPunicaWrapper(8192, 256, device="hpu")
     else:
@@ -1076,7 +1075,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
         max_position,
         base,
         is_neox_style, {
-            "type": "linear",
+            "rope_type": "linear",
             "factor": scaling_factors
         },
         dtype=torch.bfloat16)
@@ -1085,7 +1084,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
     _, index_mapping, prompt_mapping = create_random_inputs(
         active_lora_ids=[0],
         num_inputs=batch_size,
-        input_size=(seq_len, max_position),
+        input_size=(1, max_position),
         input_range=(0, lora_config.lora_extra_vocab_size),
         input_type=torch.bfloat16,
     )
5 changes: 2 additions & 3 deletions tests/lora/test_llama_hpu.py
@@ -1,8 +1,7 @@
 from typing import List

-from conftest import cleanup
-
 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest

 MODEL_PATH = "meta-llama/Llama-2-7b-hf"
@@ -73,7 +72,7 @@ def _test_llama_lora(sql_lora_files, tp_size):
     assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output

     print("removing lora")
-    cleanup()
+    cleanup_dist_env_and_memory()


 def test_llama_lora_1x(sql_lora_files):
23 changes: 12 additions & 11 deletions tests/lora/test_long_context_hpu.py
@@ -28,9 +28,15 @@
 def _create_lora_request(lora_id, long_context_infos):
     context_len = long_context_infos[lora_id]["context_length"]
     scaling_factor = context_len_to_scaling_factor[context_len]
-    return LoRARequest(f'{context_len}_{lora_id}', lora_id,
-                       long_context_infos[lora_id]["lora"], None,
-                       4096 * scaling_factor)
+    return LoRARequest(
+        # There are 2 LoRAs for 16K, we need to add lora_id to indicate
+        # they are different LoRAs.
+        context_len + str(lora_id),
+        lora_id,
+        long_context_infos[lora_id]["lora"],
+        None,
+        4096 * scaling_factor,
+    )


 def evaluate_json_response(model_response, golden_response):
@@ -117,7 +123,8 @@ def lora_llm(long_context_infos):
         max_num_batched_tokens=4096 * 8,
         tensor_parallel_size=1,
         dtype="bfloat16",
-        disable_async_output_proc=True,  # TODO Remove after SW-204469 is fixed.
+        # FIXME enable async output processor
+        disable_async_output_proc=True,
         distributed_executor_backend="mp")
     yield llm
     del llm
@@ -136,13 +143,7 @@ def test_rotary_emb_replaced(dist_init):
                              enable_lora=True)
     engine_config = engine_args.create_engine_config()
     model_runner = ModelRunner(
-        model_config=engine_config.model_config,
-        parallel_config=engine_config.parallel_config,
-        scheduler_config=engine_config.scheduler_config,
-        device_config=engine_config.device_config,
-        cache_config=engine_config.cache_config,
-        load_config=engine_config.load_config,
-        lora_config=engine_config.lora_config,
+        vllm_config=engine_config,
         is_driver_worker=True,
     )
     model_runner.load_model()
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -203,9 +203,10 @@ def forward_hpu(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-        positions = positions.flatten()
         if offsets is not None:
+            offsets = offsets.view(positions.shape[0], -1)
             positions = positions + offsets
+        positions = positions.flatten()
         num_tokens = positions.shape[0]
         cos_sin = self.cos_sin_cache.index_select(0, positions).view(
             num_tokens, 1, -1)
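
Taken out of diff context, the reshaped-offsets lookup amounts to roughly the sketch below (simplified: the real `forward_hpu` goes on to call `apply_rotary_pos_emb` from `habana_frameworks`; here only the position/offset handling and the `cos_sin_cache` indexing are reproduced, with a made-up cache for demonstration):

```python
from typing import Optional

import torch


def lookup_cos_sin(positions: torch.Tensor, offsets: Optional[torch.Tensor],
                   cos_sin_cache: torch.Tensor) -> torch.Tensor:
    """Sketch of the patched offset handling (not the full rotary embedding)."""
    if offsets is not None:
        # View as (batch_size, -1) so each row of positions gets its own offset.
        offsets = offsets.view(positions.shape[0], -1)
        positions = positions + offsets
    positions = positions.flatten()
    num_tokens = positions.shape[0]
    return cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1)


# Toy usage: a cache of 8192 rotary entries of width 128, 2 sequences of 3 tokens.
cache = torch.randn(8192, 128)
pos = torch.tensor([[0, 1, 2], [0, 1, 2]])
off = torch.tensor([0, 4096])
print(lookup_cos_sin(pos, off, cache).shape)  # torch.Size([6, 1, 128])
```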