diff --git a/tests/lora/test_long_context_hpu.py b/tests/lora/test_long_context_hpu.py
new file mode 100644
index 0000000000000..33250edde00d3
--- /dev/null
+++ b/tests/lora/test_long_context_hpu.py
@@ -0,0 +1,304 @@
+import ast
+from typing import List, Optional, Tuple
+
+import numpy as np
+import pytest
+
+import vllm
+from vllm import SamplingParams
+from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora
+from vllm.lora.request import LoRARequest
+from vllm.model_executor.layers.rotary_embedding import (
+    LinearScalingRotaryEmbedding)
+
+from .data.long_context_test_data import prompts_and_responses
+
+context_len_to_scaling_factor = {
+    "16k": 4,
+    "32k": 8,
+}
+
+# We use the same sampling params for all requests
+sampling_params = SamplingParams(
+    temperature=0,
+    max_tokens=100,
+)
+
+
+def _create_lora_request(lora_id, long_context_infos):
+    context_len = long_context_infos[lora_id]["context_length"]
+    scaling_factor = context_len_to_scaling_factor[context_len]
+    return LoRARequest(f'{context_len}_{lora_id}', lora_id,
+                       long_context_infos[lora_id]["lora"], None,
+                       4096 * scaling_factor)
+
+
+def evaluate_json_response(model_response, golden_response):
+    """Evaluates the model response against the golden response.
+
+    Returns a score between 0 and 1, where 1 is a perfect match and 0 is no
+    match. The score quantifies how well the model is able to extract the
+    golden JSON from the long context.
+    """
+    try:
+        model_response = ast.literal_eval(model_response)
+    except Exception as e:
+        raise ValueError(
+            f"Model response is not a valid JSON. Expected {golden_response}, "
+            f"got {model_response}") from e
+
+    # Normally, we would flatten the dictionary and compare the values, but in
+    # this case, we know that the dictionary is only 2 levels deep
+    positive_values = 0
+    total_values = 0
+    # We look at all the attributes of the person that we are extracting a
+    # biography of and compare them to the golden response
+    for person_attribute, person_attribute_value in golden_response.items():
+        if person_attribute in model_response:
+            if isinstance(person_attribute_value, dict):
+                for (sub_attribute,
+                     sub_attribute_value) in person_attribute_value.items():
+                    total_values += 1
+                    if sub_attribute in model_response[
+                            person_attribute] and model_response[
+                                person_attribute][
+                                    sub_attribute] == sub_attribute_value:
+                        positive_values += 1
+            else:
+                total_values += 1
+                if model_response[person_attribute] == person_attribute_value:
+                    positive_values += 1
+        else:
+            # We count a missing sub-dict as a single missed value.
+            total_values += 1
+
+    # Return a score between 0 and 1
+    return positive_values / total_values
+
+
+def generate(
+    llm: vllm.LLM,
+    inputs: Tuple[str, SamplingParams, Optional[LoRARequest]],
+):
+    prompts, sampling_param, lora_request = inputs
+    outputs = llm.generate(prompts, sampling_param, lora_request=lora_request)
+    return outputs[0].outputs[0].text.strip()
+
+
+def batched_generate(
+    llm: vllm.LLM,
+    inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
+):
+    for input in inputs:
+        prompt, sampling_param, lora_req = input
+        # Add requests to the engine and run the engine
+        llm._validate_and_add_requests(prompt,
+                                       sampling_param,
+                                       lora_request=lora_req,
+                                       prompt_adapter_request=None)
+
+    outputs = llm._run_engine(use_tqdm=True)
+    return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
+
+
+@pytest.fixture(scope="module")
+def lora_llm(long_context_infos):
+    scaling_factors = [
+        context_len_to_scaling_factor[info["context_length"]]
+        for info in long_context_infos.values()
+    ]
+
+    llm = vllm.LLM(
+        "meta-llama/Llama-2-13b-chat-hf",
+        enable_lora=True,
+        max_num_seqs=16,
+        max_loras=2,
+        long_lora_scaling_factors=tuple(scaling_factors),
+        max_num_batched_tokens=4096 * 8,
+        tensor_parallel_size=1,
+        enforce_eager=True,  # TODO Remove after SW-205153 is fixed
+        dtype="bfloat16",
+        disable_async_output_proc=True,  # TODO Remove after SW-204469 is fixed.
+        distributed_executor_backend="mp")
+    yield llm
+    del llm
+
+
+def test_rotary_emb_replaced(dist_init):
+    """Verify that the rotary embeddings in all the layers are replaced."""
+    from vllm.engine.arg_utils import EngineArgs
+    from vllm.platforms import current_platform
+    if current_platform.is_hpu():
+        from vllm.worker.hpu_model_runner import HPUModelRunner as ModelRunner
+    else:
+        from vllm.worker.model_runner import ModelRunner
+    engine_args = EngineArgs("meta-llama/Llama-2-7b-hf",
+                             long_lora_scaling_factors=(4.0, ),
+                             enable_lora=True)
+    engine_config = engine_args.create_engine_config()
+    model_runner = ModelRunner(
+        model_config=engine_config.model_config,
+        parallel_config=engine_config.parallel_config,
+        scheduler_config=engine_config.scheduler_config,
+        device_config=engine_config.device_config,
+        cache_config=engine_config.cache_config,
+        load_config=engine_config.load_config,
+        lora_config=engine_config.lora_config,
+        is_driver_worker=True,
+    )
+    model_runner.load_model()
+    rotary_emb_count = 0
+    model = model_runner.model.model if current_platform.is_hpu(
+    ) else model_runner.model
+    for module_name, module in model.named_modules(remove_duplicate=False):
+        if "rotary_emb" in module_name:
+            if "base_layer" not in module_name:
+                rotary_emb_count += 1
+                assert isinstance(module, LinearScalingRotaryEmbeddingWithLora)
+            else:
+                assert isinstance(module, LinearScalingRotaryEmbedding)
+    # Llama 2 has 32 layers.
+    assert rotary_emb_count == 32
+
+
+@pytest.mark.skip_global_cleanup
+def test_batched_rope_kernel(lora_llm, long_context_infos):
+    """We test the batched kernel by comparing the results of batched and
+    non-batched generation.
+    """
+    # Create non-batched results first to compare against batched results
+    non_batched_results: List[str] = []
+
+    for lora_id, info in long_context_infos.items():
+        context_len = info["context_length"]
+        lora_prompt = (prompts_and_responses[context_len][0]["prompt"],
+                       sampling_params,
+                       _create_lora_request(lora_id, long_context_infos))
+        lora_output = generate(lora_llm, lora_prompt)
+        non_batched_results.append(lora_output)
+
+    # Create batched results
+    # Each element of the batch must be
+    # (prompt, prompt_sampling_params, prompt_lora_request)
+    batched_prompts: List[Tuple[str, SamplingParams,
+                                Optional[LoRARequest]]] = []
+    for lora_id, info in long_context_infos.items():
+        context_len = info["context_length"]
+        batched_prompts.extend([
+            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
+             _create_lora_request(lora_id, long_context_infos))
+        ])
+    batched_results = batched_generate(lora_llm, batched_prompts)
+
+    # Results should be the same
+    for non_batched, batched in zip(non_batched_results, batched_results):
+        assert non_batched == batched, (
+            "Non-batched and batched results should be the "
+            f"same:\n{batched}\n{non_batched}")
+
+
+@pytest.mark.skip_global_cleanup
+def test_self_consistency(lora_llm, long_context_infos):
+    """We test consistency of the batched kernel by permuting batched
+    inputs and comparing the results to the non-permuted batched results.
+    """
+    num_loras = len(long_context_infos)
+
+    # Create results in order of long_context_infos
+    batched_prompts: List[Tuple[str, SamplingParams,
+                                Optional[LoRARequest]]] = []
+    for lora_id, info in long_context_infos.items():
+        context_len = info["context_length"]
+        batched_prompts.extend([
+            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
+             _create_lora_request(lora_id, long_context_infos))
+        ])
+
+    batched_results = batched_generate(lora_llm, batched_prompts)
+
+    permutation = np.random.default_rng(seed=42).permutation(num_loras)
+
+    # Create results in random order of permutation
+    batched_prompts = []
+    for i in permutation:
+        lora_id, info = list(long_context_infos.items())[i]
+        context_len = info["context_length"]
+        batched_prompts.extend([
+            (prompts_and_responses[context_len][0]["prompt"], sampling_params,
+             _create_lora_request(lora_id, long_context_infos))
+        ])
+
+    permutated_batched_results = batched_generate(lora_llm, batched_prompts)
+
+    # Results should be the same
+    for i in range(num_loras):
+        assert batched_results[i] == permutated_batched_results[
+            permutation[i]], (
+                f"Results should be the same:\n{batched_results[i]}"
+                f"\n{permutated_batched_results[permutation[i]]}")
+
+
+@pytest.mark.skip_global_cleanup
+def test_quality(lora_llm, long_context_infos):
+    """We test the quality of the answers given by the LoRA model by
+    comparing the generated text to the merged model's outputs.
+
+    This is effectively a mini-benchmark over four prompts.
+    If this test fails, this indicates that the quality of the LoRA model
+    is suboptimal compared to the merged model. For example, if the model
+    does not output valid dictionaries, this test will fail.
+
+    If needed for testing, the merged versions of the models are available
+    as part of the `conftest`.
+
+    The test is expected to run for about 1 minute on a p4de.24xlarge
+    instance.
+    """
+    scores: List[float] = []
+    for lora_id, info in long_context_infos.items():
+        context_len = info["context_length"]
+        for prompt_and_response in prompts_and_responses[context_len]:
+            lora_prompt = (prompt_and_response["prompt"], sampling_params,
+                           _create_lora_request(lora_id, long_context_infos))
+            response = generate(lora_llm, lora_prompt)
+            golden_answer = prompt_and_response["golden_answer"]
+            score = evaluate_json_response(response, golden_answer)
+            scores.append(score)
+            assert score > 0.3, ("Quality of the answer is not good enough. "
+                                 f"Expected {golden_answer}, got {response}")
+    assert np.mean(scores) > 0.5
+
+
+@pytest.mark.skip_global_cleanup
+def test_max_len(lora_llm, long_context_infos):
+    """Test that we raise a ValueError when the input of a given LoRA
+    model exceeds the maximum length."""
+    # Since each LoRA model has a different maximum length, we need to
+    # test each one separately
+    for lora_id, info in long_context_infos.items():
+        context_len = info["context_length"]
+        lora_request = _create_lora_request(lora_id, long_context_infos)
+        # Good prompt should be fine
+        good_prompt = prompts_and_responses[context_len][0]["prompt"]
+        generate(lora_llm, (good_prompt, sampling_params, lora_request))
+        # Bad prompt should raise an error
+        bad_prompt = good_prompt * 2
+        with pytest.raises(ValueError):
+            generate(lora_llm, (bad_prompt, sampling_params, lora_request))
+
+    # Also test batched
+    batched_prompts: List[Tuple[str, SamplingParams,
+                                Optional[LoRARequest]]] = []
+    for lora_id_with_bad_inputs in long_context_infos:
+        for lora_id, info in long_context_infos.items():
+            context_len = info["context_length"]
+            batched_prompts.extend([
+                (prompts_and_responses[context_len][0]["prompt"] *
+                 (2 if lora_id == lora_id_with_bad_inputs else 1),
+                 sampling_params,
+                 _create_lora_request(lora_id, long_context_infos))
+            ])
+        # Turn good prompt into bad prompt inside of batched prompts
+
+        with pytest.raises(ValueError):
+            batched_generate(lora_llm, batched_prompts)
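For reference, the scoring above counts each scalar attribute and each sub-attribute of a nested dict as one value, and any top-level attribute missing from the model response as a single miss. A quick illustration of the resulting score, assuming `evaluate_json_response` from the new test file is in scope; the golden/response values below are made up for illustration:

    # One top-level match plus one of two job sub-attributes -> score 2/3.
    golden = {"name": "Ada", "job": {"title": "engineer", "company": "Acme"}}
    response = "{'name': 'Ada', 'job': {'title': 'engineer', 'company': 'Bcme'}}"
    assert abs(evaluate_json_response(response, golden) - 2 / 3) < 1e-9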
diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py
index f22f92b6fe64b..1fdd15df99c19 100644
--- a/vllm/lora/punica.py
+++ b/vllm/lora/punica.py
@@ -103,10 +103,15 @@ def convert_mapping(
     embedding_indices = index_mapping_indices.copy()
     lora_indices = index_mapping_indices.copy()
     long_lora_offsets: Optional[torch.Tensor] = None
+
+    from vllm.platforms import current_platform
     if long_lora_context:
-        long_lora_offsets = torch.zeros(len(index_mapping_indices),
-                                        device=get_device(),
-                                        dtype=torch.long)
+        if current_platform.is_hpu():
+            long_lora_offsets_list: List[int] = []
+        else:
+            long_lora_offsets = torch.zeros(len(index_mapping_indices),
+                                            device=get_device(),
+                                            dtype=torch.long)
     prompt_mapping: List[int] = [
         lora_index_to_id.index(x) if x > 0 else -1
         for x in mapping.prompt_mapping
@@ -119,10 +124,18 @@ def convert_mapping(
         embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
         lora_indices[i] = lora_idx
         if long_lora_context:
-            assert long_lora_offsets is not None
             lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                 index_mapping_indices[i], 0)
-            long_lora_offsets[i] = lora_offset
+            if current_platform.is_hpu():
+                long_lora_offsets_list.append(lora_offset)
+            else:
+                assert long_lora_offsets is not None
+                long_lora_offsets[i] = lora_offset
+
+    if long_lora_context and current_platform.is_hpu():
+        long_lora_offsets = torch.tensor(long_lora_offsets_list,
+                                         device=get_device(),
+                                         dtype=torch.long)
 
     indices_list: List[Union[List[int], torch.Tensor]] = [
         index_mapping_indices,
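For context on the punica.py change above: on HPU, the per-token long-LoRA offsets are first collected in a plain Python list and only then materialized as a single tensor, rather than being written element by element into a preallocated device tensor. A minimal standalone sketch of that pattern (the helper name and example values are hypothetical; `offsets_by_lora_id` mirrors `long_lora_context.offsets_by_lora_id` from the diff):

    import torch

    def build_long_lora_offsets(index_mapping_indices, offsets_by_lora_id,
                                device="cpu"):
        # One offset per token position; 0 for positions without a
        # long-context LoRA.
        offsets_list = [
            offsets_by_lora_id.get(idx, 0) for idx in index_mapping_indices
        ]
        # A single host-to-device copy instead of per-element tensor writes.
        return torch.tensor(offsets_list, device=device, dtype=torch.long)

    # Two positions mapped to LoRA id 1 (rope offset 4096), one to id 2,
    # one to no LoRA: tensor([4096, 4096, 0, 0])
    print(build_long_lora_offsets([1, 1, 2, 0], {1: 4096}))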
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 382a0abb21240..b5100491c4135 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -37,6 +37,7 @@
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import supports_multimodal
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs)
 from vllm.sampling_params import SamplingParams
@@ -649,12 +650,30 @@ def load_model(self) -> None:
             assert hasattr(
                 self.model, "embedding_padding_modules"
             ), "Model does not have embedding_padding_modules"
+
+            if supports_multimodal(self.model):
+                logger.warning(
+                    "Regarding multimodal models, vLLM currently "
+                    "only supports adding LoRA to language model.")
+            # It's necessary to distinguish between the
+            # max_position_embeddings of VLMs and LLMs.
+            if hasattr(self.model.config, "max_position_embeddings"):
+                max_pos_embeddings = (
+                    self.model.config.max_position_embeddings)
+            else:
+                max_pos_embeddings = (
+                    self.model.config.text_config.max_position_embeddings)
+
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens,
-                self.vocab_size, self.lora_config, self.device,
+                self.vocab_size,
+                self.lora_config,
+                self.device,
                 self.model.embedding_modules,
-                self.model.embedding_padding_modules)
+                self.model.embedding_padding_modules,
+                max_position_embeddings=max_pos_embeddings,
+            )
             self.model = self.lora_manager.create_lora_manager(self.model)
 
         if self.model_config.quantization == 'inc':
@@ -1314,7 +1333,8 @@ def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
         max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
-        max_batch_size = self.max_num_batched_tokens // max_seq_len
+        max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
+                             self.scheduler_config.max_num_seqs)
         self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                              False, True)
 
@@ -1333,7 +1353,6 @@ def warmup_scenario(self,
                          f"bs{batch_size}_"
                          f"seq{seq_len}_"
                          f"graphs{'T' if use_graphs else 'F'}")
-        max_num_seqs = self.scheduler_config.max_num_seqs
         # This represents the maximum number of different requests
         # that will have unique loras, an therefore the max amount of memory
         # consumption create dummy lora request copies from the lora request
@@ -1355,7 +1374,7 @@ def warmup_scenario(self,
                     dummy_lora_requests.append(dummy_lora_request)
             dummy_lora_requests_per_seq = [
                 dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
+                for idx in range(batch_size)
             ]
         self.profiler.start('internal', scenario_name)
         times = 3 if use_graphs or is_pt_profiler_run else 1
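The load_model() hunk above also starts passing max_position_embeddings to LRUCacheWorkerLoRAManager, reading it from `config.max_position_embeddings` for plain language models and from `config.text_config` for multimodal models. A small standalone sketch of just that lookup, using SimpleNamespace stand-ins for the HuggingFace config objects (values are illustrative):

    from types import SimpleNamespace

    def get_max_pos_embeddings(config) -> int:
        # LLM configs expose the value directly; VLM configs nest the
        # language model's settings under text_config.
        if hasattr(config, "max_position_embeddings"):
            return config.max_position_embeddings
        return config.text_config.max_position_embeddings

    llm_config = SimpleNamespace(max_position_embeddings=4096)
    vlm_config = SimpleNamespace(
        text_config=SimpleNamespace(max_position_embeddings=8192))
    print(get_max_pos_embeddings(llm_config))  # 4096
    print(get_max_pos_embeddings(vlm_config))  # 8192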