diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index eb882faf3974a..15e576cb065c7 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -51,6 +51,7 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] + # We will launch different triton kernels between the prefill and decode # stages, so we need to verify this. prefill stage(True) or decode stage(False) STAGES = [True, False] @@ -120,11 +121,12 @@ def populate_loras( subloras: List[LoRALayerWeights] = [] sublora_len = layer_weights.shape[0] // repeats for i in range(repeats): - sublora = DummyLoRAManager().init_random_lora( - module_name=f"fake_{i}", - weight=layer_weights, - generate_embeddings_tensor=generate_embeddings_tensor, - ) + sublora = DummyLoRAManager( + layer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) sublora.lora_b = sublora.lora_b[:, (sublora_len * i):(sublora_len * (i + 1))] sublora.optimize() @@ -152,6 +154,7 @@ def create_random_inputs( input_size: Tuple[int, ...], input_range: Tuple[float, float], input_type: torch.dtype = torch.int, + device: torch.device = "cuda" ) -> Tuple[List[torch.Tensor], List[int], List[int]]: """Creates random inputs. @@ -173,10 +176,14 @@ def create_random_inputs( for _ in range(num_inputs): if input_type == torch.int: inputs.append( - torch.randint(low=int(low), high=int(high), size=input_size)) + torch.randint(low=int(low), + high=int(high), + size=input_size, + device=device)) else: inputs.append( - torch.rand(size=input_size, dtype=input_type) * high + low) + torch.rand(size=input_size, dtype=input_type, device=device) * + high + low) lora_id = random.choice(active_lora_ids) index_mapping += [lora_id] * input_size[0] @@ -191,6 +198,10 @@ def create_random_inputs( @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: + # For multi-GPU testing of Triton kernel, we must explicitly set the CUDA + # device, see: https://github.com/triton-lang/triton/issues/2925 + # Same below. + torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 @@ -225,7 +236,7 @@ def create_random_embedding_layer(): num_inputs=num_loras * 3, input_size=(200, ), input_range=(1, vocab_size), - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -263,7 +274,7 @@ def create_random_embedding_layer(): num_inputs=num_loras * 3, input_size=(200, ), input_range=(1, vocab_size), - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -291,6 +302,7 @@ def create_random_embedding_layer(): def test_embeddings_with_new_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: + torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 punica_wrapper = PunicaWrapper(8192, 256, device) @@ -345,7 +357,7 @@ def create_random_embedding_layer(): num_inputs=num_loras * 3, input_size=(200, ), input_range=(1, vocab_size), - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -400,7 +412,7 @@ def create_random_embedding_layer(): num_inputs=num_loras * 3, input_size=(200, ), input_range=(1, vocab_size), - ) + device=device) original_inputs = deepcopy(inputs) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, @@ -426,6 +438,7 @@ def create_random_embedding_layer(): def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, stage) -> None: + torch.cuda.set_device(device) torch.set_default_device(device) max_loras = 8 punica_wrapper = PunicaWrapper(8192, 256, device) @@ -471,7 +484,7 @@ def _pretest(): input_size=(1, 1024), input_range=(0, 1), input_type=torch.float16, - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -520,7 +533,7 @@ def _pretest(): input_size=(1, 1024), input_range=(0, 1), input_type=torch.float16, - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -554,6 +567,7 @@ def _pretest(): @pytest.mark.parametrize("stage", STAGES) def test_linear_replicated(dist_init, num_loras, device, stage) -> None: + torch.cuda.set_device(device) torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 @@ -592,7 +606,7 @@ def create_random_linear_replicated_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -631,7 +645,7 @@ def create_random_linear_replicated_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -658,6 +672,7 @@ def create_random_linear_replicated_layer(): def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, device, stage) -> None: + torch.cuda.set_device(device) torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 @@ -706,7 +721,7 @@ def create_random_linear_parallel_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -745,7 +760,7 @@ def create_random_linear_parallel_layer(): input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -772,6 +787,7 @@ def create_random_linear_parallel_layer(): def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, device, stage) -> None: + torch.cuda.set_device(device) torch.set_default_device(device) punica_wrapper = PunicaWrapper(8192, 256, device) max_loras = 8 @@ -842,7 +858,7 @@ class FakeConfig: input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -883,7 +899,7 @@ class FakeConfig: input_size=(1, 4096), input_range=(0, 1), input_type=torch.float16, - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) @@ -962,7 +978,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, input_size=(1, max_position), input_range=(0, lora_config.lora_extra_vocab_size), input_type=torch.float16, - ) + device=device) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) long_lora_context = LongContextLoRAContext(list(scaling_factors), diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 67cf298b4df2b..8d109b2c81503 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -25,8 +25,13 @@ EMBEDDING_PADDING_MODULES = ["lm_head"] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] -def test_from_lora_tensors(sql_lora_files): + +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_from_lora_tensors(sql_lora_files, device): tensors = load_file( os.path.join(sql_lora_files, "adapter_model.safetensors")) new_embeddings = load_file( @@ -36,7 +41,7 @@ def test_from_lora_tensors(sql_lora_files): 8, 16, tensors, - "cuda", + device, embeddings=new_embeddings, embedding_modules=EMBEDDING_MODULES, embedding_padding_modules=EMBEDDING_PADDING_MODULES) @@ -46,6 +51,8 @@ def test_from_lora_tensors(sql_lora_files): assert lora.lora_alpha == 16 assert lora.lora_a is not None assert lora.lora_b is not None + assert lora.lora_a.device == torch.device(device) + assert lora.lora_b.device == torch.device(device) assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" assert lora.lora_a.shape[1] == 8 @@ -60,8 +67,8 @@ def test_from_lora_tensors(sql_lora_files): assert lora.embeddings_tensor is None -def create_lora(lora_id: int, model: nn.Module, - sub_modules: List[str]) -> LoRAModel: +def create_lora(lora_id: int, model: nn.Module, sub_modules: List[str], + device: torch.device) -> LoRAModel: loras: Dict[str, LoRALayerWeights] = {} for name in sub_modules: w = model.get_submodule(name).weight @@ -69,8 +76,8 @@ def create_lora(lora_id: int, model: nn.Module, name, 8, 16, - torch.rand([w.shape[1], 8], device="cuda"), - torch.rand([8, w.shape[0]], device="cuda"), + torch.rand([w.shape[1], 8], device=device), + torch.rand([8, w.shape[0]], device=device), ) return LoRAModel(lora_id, 8, loras) @@ -80,6 +87,7 @@ def create_packed_lora( model: nn.Module, module_name, replaced_module_names, + device: torch.device, empty_replaced_module_name=None, ) -> LoRAModel: w = model.get_submodule(module_name).weight @@ -91,9 +99,9 @@ def create_packed_lora( replaced_module_name, 8, 16, - torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([w.shape[1], 8], device=device), torch.rand([8, w.shape[0] // len(replaced_module_names)], - device="cuda"), + device=device), ) return LoRAModel(lora_id, 8, loras) @@ -104,7 +112,8 @@ def test_replace_submodules(dist_init, dummy_model): model.packed_modules_mapping = {} manager = LoRAModelManager( model, 1, 1, 1, - LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8)) + LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8), + torch.device("cuda")) model = manager.model assert isinstance(model.get_submodule("dense1"), @@ -116,16 +125,28 @@ def test_replace_submodules(dist_init, dummy_model): RowParallelLinearWithLoRA) -def test_lora_model_manager(dist_init, dummy_model): +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_lora_model_manager(dist_init, dummy_model, device): model = dummy_model model.supported_lora_modules = ["dense1", "dense2", "lm_head"] model.packed_modules_mapping = {} - model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) - model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) - model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) - manager = LoRAModelManager( - model, 2, 2, 2, - LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) + model_lora1 = create_lora(1, + model, ["layer1.dense1", "dense2", "lm_head"], + device=device) + model_lora2 = create_lora(2, + model, ["dense1", "dense2", "lm_head"], + device=device) + model_lora3 = create_lora(3, + model, ["dense1", "dense2", "lm_head"], + device=device) + manager = LoRAModelManager(model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, + max_cpu_loras=3, + max_loras=2), + device=device) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) assert manager.activate_adapter(1) @@ -161,17 +182,32 @@ def test_lora_model_manager(dist_init, dummy_model): assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 + assert manager.device == device + assert manager.punica_wrapper.device == device -def test_lora_lru_cache_model_manager(dist_init, dummy_model): + +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): model = dummy_model model.supported_lora_modules = ["dense1", "dense2", "lm_head"] model.packed_modules_mapping = {} - model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) - model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) - model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) - manager = LRUCacheLoRAModelManager( - model, 2, 2, 2, - LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) + model_lora1 = create_lora(1, + model, ["layer1.dense1", "dense2", "lm_head"], + device=device) + model_lora2 = create_lora(2, + model, ["dense1", "dense2", "lm_head"], + device=device) + model_lora3 = create_lora(3, + model, ["dense1", "dense2", "lm_head"], + device=device) + manager = LRUCacheLoRAModelManager(model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, + max_cpu_loras=3, + max_loras=2), + device=device) assert all(x is None for x in manager.lora_index_to_id) assert manager.add_adapter(model_lora1) assert manager.activate_adapter(1) @@ -238,20 +274,37 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): with pytest.raises(ValueError): assert manager.pin_adapter(3) + assert manager.punica_wrapper.device == device + assert manager.device == device + -def test_lru_lora_model_manager(dist_init, dummy_model): +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_lru_lora_model_manager(dist_init, dummy_model, device): # This tests just the LRU cache functionality, everything else is # tested in test_lora_model_manager model = dummy_model model.supported_lora_modules = ["dense1", "dense2", "lm_head"] model.packed_modules_mapping = {} - model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) - model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) - model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) - model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"]) - manager = LRUCacheLoRAModelManager( - model, 2, 2, 2, - LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2)) + model_lora1 = create_lora(1, + model, ["layer1.dense1", "dense2", "lm_head"], + device=device) + model_lora2 = create_lora(2, + model, ["dense1", "dense2", "lm_head"], + device=device) + model_lora3 = create_lora(3, + model, ["dense1", "dense2", "lm_head"], + device=device) + model_lora4 = create_lora(4, + model, ["dense1", "dense2", "lm_head"], + device=device) + manager = LRUCacheLoRAModelManager(model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, + max_cpu_loras=2, + max_loras=2), + device=device) assert all(x is None for x in manager.lora_index_to_id) @@ -351,14 +404,17 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.remove_oldest_adapter() assert set(manager.list_adapters()) == {1} + assert manager.punica_wrapper.device == device + assert manager.device == device +@pytest.mark.parametrize("device", CUDA_DEVICES) def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files): + sql_lora_files, device): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) worker_adapter_manager = LRUCacheWorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - - lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), + lora_config.lora_extra_vocab_size, lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) worker_adapter_manager.create_lora_manager( llama_2_7b_model_extra_embeddings) @@ -426,14 +482,19 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, LoRARequest("14", 14, sql_lora_files) ], mapping) + assert worker_adapter_manager.device == device + assert (worker_adapter_manager._adapter_manager.punica_wrapper.device == + device) + +@pytest.mark.parametrize("device", CUDA_DEVICES) def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files): + sql_lora_files, device): # Should remove every LoRA not specified in the request. lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) worker_adapter_manager = WorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - - lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), + lora_config.lora_extra_vocab_size, lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) worker_adapter_manager.create_lora_manager( llama_2_7b_model_extra_embeddings) @@ -497,8 +558,13 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, LoRARequest("14", 14, sql_lora_files) ], mapping) + assert worker_adapter_manager.device == device + assert (worker_adapter_manager._adapter_manager.punica_wrapper.device == + device) + -def test_packed_loras(dist_init, dummy_model_gate_up): +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_packed_loras(dist_init, dummy_model_gate_up, device): model = dummy_model_gate_up model.supported_lora_modules = ["gate_up_proj"] model.packed_modules_mapping = { @@ -511,18 +577,25 @@ def test_packed_loras(dist_init, dummy_model_gate_up): 1, model, module_name="gate_up_proj", - replaced_module_names=["gate_proj", "up_proj"]) + replaced_module_names=["gate_proj", "up_proj"], + device=device) model_lora1 = create_packed_lora( 2, model, module_name="gate_up_proj", replaced_module_names=["gate_proj", "up_proj"], + device=device, empty_replaced_module_name="gate_proj", ) - manager = LoRAModelManager( - model, 2, 2, 2, - LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2)) + manager = LoRAModelManager(model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, + max_cpu_loras=2, + max_loras=2), + device=device) model = manager.model assert isinstance(model.get_submodule("gate_up_proj"), diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 00f8e26d1041f..e394c33b3f9ea 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -7,9 +7,10 @@ class DummyLoRAManager: - def __init__(self): + def __init__(self, device: torch.device = "cuda:0"): super().__init__() self._loras: Dict[str, LoRALayerWeights] = {} + self._device = device def set_module_lora(self, module_name: str, lora: LoRALayerWeights): self._loras[module_name] = lora @@ -28,16 +29,16 @@ def init_random_lora(self, lora_alpha=1, lora_a=torch.rand([weight.shape[1], rank], dtype=weight.dtype, - device="cuda"), + device=self._device), lora_b=torch.rand([rank, weight.shape[0]], dtype=weight.dtype, - device="cuda"), + device=self._device), ) if generate_embeddings_tensor: lora.embeddings_tensor = torch.rand(5, generate_embeddings_tensor, dtype=weight.dtype, - device="cuda") + device=self._device) self.set_module_lora(module_name, lora) return lora diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 81e274612b73b..eafc3a43a2846 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -301,6 +301,7 @@ def __init__( max_num_batched_tokens: int, vocab_size: int, lora_config: LoRAConfig, + device: torch.device, ): """Create a LoRAModelManager and adapter for a given model. @@ -314,6 +315,7 @@ def __init__( lora_config: the LoRA configuration. """ self.lora_config = lora_config + self.device = device self.max_num_seqs = max_num_seqs assert self.capacity >= self.lora_slots self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 @@ -322,7 +324,7 @@ def __init__( self.long_lora_context: Optional[LongContextLoRAContext] = None self.punica_wrapper = PunicaWrapper(max_num_batched_tokens, max_batches=self.max_num_seqs, - device="cuda") + device=self.device) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} @@ -653,16 +655,11 @@ def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], class LRUCacheLoRAModelManager(LoRAModelManager): """A model manager that manages multiple LoRAs with LRU cache.""" - def __init__( - self, - model: nn.Module, - max_num_seqs: int, - max_num_batched_tokens: int, - vocab_size: int, - lora_config: LoRAConfig, - ): + def __init__(self, model: nn.Module, max_num_seqs: int, + max_num_batched_tokens: int, vocab_size: int, + lora_config: LoRAConfig, device: torch.device): super().__init__(model, max_num_seqs, max_num_batched_tokens, - vocab_size, lora_config) + vocab_size, lora_config, device) self._registered_adapters: LoRALRUCache = LoRALRUCache( self.capacity, self.deactivate_adapter) self._active_adapters: LoRALRUCache = LoRALRUCache( @@ -732,6 +729,7 @@ def create_lora_manager( max_num_batched_tokens: int, vocab_size: int, lora_config: LoRAConfig, + device: torch.device, lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, **kwargs) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" @@ -743,5 +741,6 @@ def create_lora_manager( max_num_batched_tokens=max_num_batched_tokens, vocab_size=vocab_size, lora_config=lora_config, + device=device, **kwargs) return lora_manager diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 5033ce4126929..082041f390750 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -62,6 +62,7 @@ def convert_mapping( max_loras: int, vocab_size: int, extra_vocab_size: int, + device: torch.device, long_lora_context: Optional["LongContextLoRAContext"] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], List[int]]: @@ -104,7 +105,7 @@ def convert_mapping( long_lora_offsets: Optional[torch.Tensor] = None if long_lora_context: long_lora_offsets = torch.zeros(len(index_mapping_indices), - device="cuda", + device=device, dtype=torch.long) prompt_mapping: List[int] = [ lora_index_to_id.index(x) if x > 0 else -1 @@ -131,10 +132,10 @@ def convert_mapping( if long_lora_context: assert long_lora_offsets is not None indices_list.append(long_lora_offsets) - indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + indices = torch.tensor(indices_list, dtype=torch.long, device=device) prompt_mapping_tensor = torch.tensor(prompt_mapping, - device="cuda", - dtype=torch.long) + dtype=torch.long, + device=device) embeddings_indices = torch.stack([ indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size), @@ -145,7 +146,7 @@ def convert_mapping( sampler_indices_padded = sampler_indices.clone() sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 sampler_indices_padded = torch.arange( - 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + ( + 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( sampler_indices_padded * len(sampler_indices_padded)) long_lora_indices = None long_lora_indices_len: Optional[int] = None @@ -183,7 +184,7 @@ class PunicaWrapper: """ def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: str): + device: Union[torch.device, str]): self._token_lora_indices = torch.empty(max_num_batched_tokens, dtype=torch.long, device=device) @@ -215,6 +216,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self._lora_indices_per_batch = torch.empty(max_batches, dtype=torch.long, device=device) + self.device: torch.device = device self.max_length: int = 0 self.token_nums: int = 0 self.batch_size: int = -1 @@ -263,6 +265,7 @@ def _update_base_metadata( max_loras, vocab_size, extra_vocab_size, + self.device, long_lora_context, ) self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 724c308a07a27..93a5e27621912 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -73,6 +73,7 @@ def create_lora_manager( max_num_batched_tokens=self.max_num_batched_tokens, vocab_size=self.vocab_size, lora_config=self.lora_config, + device=self.device, lora_manager_cls=self._manager_cls, ) self._adapter_manager = lora_manager @@ -176,6 +177,7 @@ def create_lora_manager( max_num_seqs=self.max_num_seqs, vocab_size=self.vocab_size, lora_config=self.lora_config, + device=self.device, max_num_batched_tokens=self.max_num_batched_tokens, ) self._adapter_manager = lora_manager