From 6eed0ef582a1f2168cb4f130d521e76530e48199 Mon Sep 17 00:00:00 2001
From: Sanju C Sudhakaran
Date: Thu, 7 Nov 2024 13:42:12 +0200
Subject: [PATCH] Handle offsets shape in long contexts

---
 tests/lora/test_layers_hpu.py                 |  9 ++++----
 tests/lora/test_llama_hpu.py                  |  5 ++--
 tests/lora/test_long_context_hpu.py           | 23 ++++++++++---------
 .../model_executor/layers/rotary_embedding.py |  3 ++-
 4 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/tests/lora/test_layers_hpu.py b/tests/lora/test_layers_hpu.py
index 7e33813c7a6a2..bbb544aa8ee2e 100644
--- a/tests/lora/test_layers_hpu.py
+++ b/tests/lora/test_layers_hpu.py
@@ -43,7 +43,6 @@
     ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
 from vllm.model_executor.utils import set_random_seed
 from vllm.platforms import current_platform
-from vllm.utils import seed_everything
 
 from .utils import DummyLoRAManager
 
@@ -1043,8 +1042,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                        seq_len) -> None:
     dtype = torch.bfloat16
     seed = 0
-    seed_everything(seed)
-    torch.set_default_device(torch.device("hpu"))
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
     if current_platform.is_hpu():
         punica_wrapper = GaudiPunicaWrapper(8192, 256, device="hpu")
     else:
@@ -1076,7 +1075,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
                            max_position,
                            base,
                            is_neox_style, {
-                               "type": "linear",
+                               "rope_type": "linear",
                                "factor": scaling_factors
                            },
                            dtype=torch.bfloat16)
@@ -1085,7 +1084,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
     _, index_mapping, prompt_mapping = create_random_inputs(
         active_lora_ids=[0],
         num_inputs=batch_size,
-        input_size=(seq_len, max_position),
+        input_size=(1, max_position),
         input_range=(0, lora_config.lora_extra_vocab_size),
         input_type=torch.bfloat16,
     )
diff --git a/tests/lora/test_llama_hpu.py b/tests/lora/test_llama_hpu.py
index 5571d727ef8e2..611380816b5b3 100644
--- a/tests/lora/test_llama_hpu.py
+++ b/tests/lora/test_llama_hpu.py
@@ -1,8 +1,7 @@
 from typing import List
 
-from conftest import cleanup
-
 import vllm
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.lora.request import LoRARequest
 
 MODEL_PATH = "meta-llama/Llama-2-7b-hf"
@@ -73,7 +72,7 @@ def _test_llama_lora(sql_lora_files, tp_size):
     assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output
 
     print("removing lora")
-    cleanup()
+    cleanup_dist_env_and_memory()
 
 
 def test_llama_lora_1x(sql_lora_files):
diff --git a/tests/lora/test_long_context_hpu.py b/tests/lora/test_long_context_hpu.py
index 3c3e1b7c1e41c..7bf10b23b4c66 100644
--- a/tests/lora/test_long_context_hpu.py
+++ b/tests/lora/test_long_context_hpu.py
@@ -28,9 +28,15 @@
 def _create_lora_request(lora_id, long_context_infos):
     context_len = long_context_infos[lora_id]["context_length"]
     scaling_factor = context_len_to_scaling_factor[context_len]
-    return LoRARequest(f'{context_len}_{lora_id}', lora_id,
-                       long_context_infos[lora_id]["lora"], None,
-                       4096 * scaling_factor)
+    return LoRARequest(
+        # There are 2 LoRAs for 16K, we need to add lora_id to indicate
+        # they are different LoRAs.
+        context_len + str(lora_id),
+        lora_id,
+        long_context_infos[lora_id]["lora"],
+        None,
+        4096 * scaling_factor,
+    )
 
 
 def evaluate_json_response(model_response, golden_response):
@@ -117,7 +123,8 @@ def lora_llm(long_context_infos):
         max_num_batched_tokens=4096 * 8,
         tensor_parallel_size=1,
         dtype="bfloat16",
-        disable_async_output_proc=True,  # TODO Remove after SW-204469 is fixed.
+        # FIXME enable async output processor
+        disable_async_output_proc=True,
         distributed_executor_backend="mp")
     yield llm
     del llm
@@ -136,13 +143,7 @@ def test_rotary_emb_replaced(dist_init):
                              enable_lora=True)
     engine_config = engine_args.create_engine_config()
     model_runner = ModelRunner(
-        model_config=engine_config.model_config,
-        parallel_config=engine_config.parallel_config,
-        scheduler_config=engine_config.scheduler_config,
-        device_config=engine_config.device_config,
-        cache_config=engine_config.cache_config,
-        load_config=engine_config.load_config,
-        lora_config=engine_config.lora_config,
+        vllm_config=engine_config,
         is_driver_worker=True,
     )
     model_runner.load_model()
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 63ceec63e8317..b81ef6c03278b 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -203,9 +203,10 @@ def forward_hpu(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingMode, apply_rotary_pos_emb)
-        positions = positions.flatten()
         if offsets is not None:
+            offsets = offsets.view(positions.shape[0], -1)
            positions = positions + offsets
+        positions = positions.flatten()
         num_tokens = positions.shape[0]
         cos_sin = self.cos_sin_cache.index_select(0, positions).view(
            num_tokens, 1, -1)
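
Note: the snippet below is not part of the patch. It is a minimal, self-contained
sketch (tensor shapes are assumed purely for illustration) of the reordered logic
in forward_hpu: `offsets` is first viewed to match the leading dimension of
`positions`, the two are added, and `positions` is flattened only afterwards, so a
flat and an already 2-D `offsets` tensor land on the same result.

    import torch

    # Assumed long-context layout: positions arrives as [batch, seq_len].
    batch, seq_len = 2, 4
    positions = torch.arange(batch * seq_len).view(batch, seq_len)

    # Exercise both offset layouts; the fill value 10 is arbitrary.
    for offsets in (torch.full((batch * seq_len,), 10),  # flattened offsets
                    torch.full((batch, seq_len), 10)):   # already 2-D offsets
        # Mirror the patched order: view to positions' leading dim, add, flatten.
        shifted = positions + offsets.view(positions.shape[0], -1)
        flat = shifted.flatten()
        assert flat.shape == (batch * seq_len,)
        print(flat.tolist())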