From 0010b766d90c1fab3547365852cb10e8d5cdda1f Mon Sep 17 00:00:00 2001 From: Sanju C Sudhakaran Date: Mon, 18 Nov 2024 09:43:52 +0200 Subject: [PATCH] Clean-up LoRA flow ... by removing unnecessary functions / variables --- vllm/lora/layers.py | 1 - vllm/lora/models.py | 112 +------------------------------- vllm/worker/hpu_model_runner.py | 9 --- 3 files changed, 1 insertion(+), 121 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index d11d46bd84162..d0c1cbdb0fd5f 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -231,7 +231,6 @@ def set_lora( def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - embeddings_indices = None embeddings_indices = self.punica_wrapper.embeddings_indices indices = embeddings_indices[1].view_as(x) full_lora_a_embeddings = F.embedding( diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 6aef981a6589b..2c3b80253a3b3 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,7 +4,7 @@ import os import re from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type import safetensors.torch import torch @@ -51,116 +51,6 @@ class LongContextLoRAContext: offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) -def convert_mapping( - mapping: LoRAMapping, - lora_index_to_id: List[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional[LongContextLoRAContext] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: - """Converts LoRAMapping to index tensors. - - Args: - mapping: LoRAMapping mapping rows in a batch to LoRA ids. - lora_index_to_id: List mapping LoRA ids to LoRA indices. - max_loras: Maximum number of LoRAs. - vocab_size: Model vocab size. - extra_vocab_size: Extra vocab size each LoRA can have. - long_lora_context: Passed if there are long context lora in a batch. - - Returns: - A tuple of tensors: - base_indices: Tensor of shape [batch_size] mapping batch rows to - LoRA indices. - sampler_indices: Tensor of shape [batch_size] mapping requests to - LoRA indices for sampler. For generation, this will be the - same as base_indicies. For prefill, this will map requests - to LoRA indices. - sampler_indices_padded: Tensor of shape [batch_size] mapping - requests to LoRA indices for sampler with padding. - Same as sampler_indicies, but -1 is replaced with - max_loras. - embeddings_indices: Tensor of shape [2, batch_size] mapping - requests to embedding indices. First row is for embeddings - added by the LoRAs, second row is for the LoRA.lora_a - embeddings. - long_lora_indices: Tensor of shape [batch_size] mapping - requests to RoPE offsets and rot dims for long LoRAs. - None if long context lora doesn't exist. - indices_len: List of lengths of the above tensors. - Used to index into each tensor. It contains length for - (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices). If long_lora doesn't - exist, it only contains first 4 entries. - """ - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() - embedding_indices = index_mapping_indices.copy() - lora_indices = index_mapping_indices.copy() - long_lora_offsets: Optional[torch.Tensor] = None - device = "hpu" if current_platform.is_hpu() else "cuda" - if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device=device, - dtype=torch.long) - prompt_mapping: List[int] = [ - lora_index_to_id.index(x) if x > 0 else -1 - for x in mapping.prompt_mapping - ] - lora_idx = None - for i in range(len(index_mapping_indices)): - # TODO index can be slow. optimize - lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) - if index_mapping_indices[i] > 0 else -1) - embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 - lora_indices[i] = lora_idx - if long_lora_context: - assert long_lora_offsets is not None - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset - - indices_list: List[Union[List[int], torch.Tensor]] = [ - index_mapping_indices, lora_indices, embedding_indices - ] - if long_lora_context: - assert long_lora_offsets is not None - indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device=device) - prompt_mapping_tensor = torch.tensor(prompt_mapping, - device=device, - dtype=torch.long) - embeddings_indices = torch.stack([ - indices[2] * extra_vocab_size, - indices[2] * (vocab_size + extra_vocab_size) - ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 - base_indices = indices[1] - sampler_indices = prompt_mapping_tensor - sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 - sampler_indices_padded = ( - torch.arange( - 0, len(sampler_indices_padded), device=device, dtype=torch.long) + - (sampler_indices_padded * len(sampler_indices_padded))) - long_lora_indices = None - long_lora_indices_len: Optional[int] = None - if long_lora_context: - long_lora_indices = indices[3] - long_lora_indices_len = long_lora_indices.shape[-1] - # Contain length of indices tensors. Used to index into each tensor. - indices_len = [ - base_indices.shape[-1], sampler_indices.shape[-1], - sampler_indices_padded.shape[-1], embeddings_indices.shape[-1] - ] - if long_lora_indices_len is not None: - indices_len.append(long_lora_indices_len) - - return (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices, indices_len) - - def get_lora_id(): global _GLOBAL_LORA_ID _GLOBAL_LORA_ID += 1 diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 97ad0a6893dd4..916169e36b99d 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -1890,15 +1890,6 @@ def get_counter_dict(self, cache_config, duration, seq_len, return counters -def unwrap_model(model): - if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): - return unwrap_model(model._orig_mod) - else: - model = list(vars(model)['_modules'].values())[0] - modules = list(vars(model)['_modules'].values()) - return modules - - class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): """ GPU model runner with sampling step.