From 3d49776bbb25927abf91bb7c5537e0006c199c16 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 29 Sep 2024 14:59:45 +0800 Subject: [PATCH] [Model][LoRA]LoRA support added for MiniCPMV2.5 (#7199) --- tests/lora/conftest.py | 5 ++ tests/lora/test_minicpmv.py | 71 +++++++++++++++ tests/lora/test_minicpmv_tp.py | 95 ++++++++++++++++++++ vllm/lora/models.py | 45 +++++++++- vllm/model_executor/models/minicpmv.py | 94 ++++++++++++++----- vllm/model_executor/models/module_mapping.py | 69 ++++++++++++++ vllm/model_executor/models/utils.py | 22 ++++- vllm/worker/model_runner.py | 8 +- 8 files changed, 378 insertions(+), 31 deletions(-) create mode 100644 tests/lora/test_minicpmv.py create mode 100644 tests/lora/test_minicpmv_tp.py create mode 100644 vllm/model_executor/models/module_mapping.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 4834a9d35a3ee..7f6f60f38b5de 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -194,6 +194,11 @@ def baichuan_zero_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init") +@pytest.fixture(scope="session") +def minicpmv_lora_files(): + return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon") + + @pytest.fixture(scope="session") def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py new file mode 100644 index 0000000000000..81b8188e638c9 --- /dev/null +++ b/tests/lora/test_minicpmv.py @@ -0,0 +1,71 @@ +from typing import List + +import vllm +from vllm.assets.image import ImageAsset +from vllm.lora.request import LoRARequest + +MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" + +PROMPT_TEMPLATE = ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + "(./)\nWhat is in the image?<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n") + +IMAGE_ASSETS = [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), +] + +# After fine-tuning with LoRA, all generated content should start begin `A`. +EXPECTED_OUTPUT = [ + "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 + "A pink cherry blossom tree with a blue sky in the background.", +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + sampling_params = vllm.SamplingParams( + temperature=0, + max_tokens=5, + stop_token_ids=[128001, 128009], # eos_id, eot_id + ) + + inputs = [{ + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": { + "image": asset.pil_image + }, + } for asset in IMAGE_ASSETS] + + outputs = llm.generate( + inputs, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None, + ) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_minicpmv_lora(minicpmv_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_num_seqs=2, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + trust_remote_code=True, + ) + + output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output1[i]) + output2 = do_sample(llm, minicpmv_lora_files, lora_id=2) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output2[i]) diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py new file mode 100644 index 0000000000000..ba29e562e58ec --- /dev/null +++ b/tests/lora/test_minicpmv_tp.py @@ -0,0 +1,95 @@ +from typing import List + +import pytest + +import vllm +from vllm.assets.image import ImageAsset +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" + +PROMPT_TEMPLATE = ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + "(./)\nWhat is in the image?<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n") + +IMAGE_ASSETS = [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), +] + +# After fine-tuning with LoRA, all generated content should start begin `A`. +EXPECTED_OUTPUT = [ + "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 + "A pink cherry blossom tree with a blue sky in the background.", +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + sampling_params = vllm.SamplingParams( + temperature=0, + max_tokens=5, + stop_token_ids=[128001, 128009], # eos_id, eot_id + ) + + inputs = [{ + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": { + "image": asset.pil_image + }, + } for asset in IMAGE_ASSETS] + + outputs = llm.generate( + inputs, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None, + ) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded): + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=2, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=2, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) + + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) + + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) + + +@multi_gpu_test(num_gpus=4) +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=2, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index bc4cab1470f44..1f80c716bc481 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -24,7 +24,9 @@ from vllm.lora.punica import PunicaWrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) -from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.interfaces import (SupportsLoRA, + supports_multimodal) +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer from vllm.utils import is_pin_memory_available @@ -332,6 +334,8 @@ def __init__( self.supported_lora_modules.append("rotary_emb") self.packed_modules_mapping = copy.deepcopy( self.model.packed_modules_mapping) + # Used to indicate whether the model is a multimodal model + self.supports_mm: bool = supports_multimodal(self.model) self.packed_modules: Dict[str, List[str]] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {} # Dict instead of a Set for compatibility with LRUCache. @@ -437,12 +441,22 @@ def _create_lora_modules(self): continue if not self._match_target_modules(module_name): continue + # A temporary approach for multimodal models to support LoRA + # TODO: Remove this restriction + if self._filter_unsupported_mm_module(module_name): + logger.warning( + "Regarding multimodal models, vLLM currently only supports " + "adding LoRA to language model, %s will be ignored.", + module_name, + ) + continue parts = module_name.split(".")[-1] packed_moduled_lst = self.packed_modules_mapping.get(parts, []) new_module = replace_submodule( self.model, module_name, from_layer(module, self.lora_slots, self.lora_config, packed_moduled_lst, self.model.config)) + # LinearScalingRotaryEmbeddingWithLora is used to handle # long context lora. Register relevant metadata. if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): @@ -460,6 +474,15 @@ def _create_lora_modules(self): module, self.lora_slots, self.lora_config, self.model.config)) + + # In some models, especially multimodal ones, layers with the same + # name may have different types, such as nn.Linear and + # ReplicatedLinear. The nn.Linear layers cannot be replaced with + # LoRA layers, leading to assertion error. The following check + # aims to prevent this error + if self.supports_mm and not isinstance(new_module, + BaseLayerWithLoRA): + continue self.register_module(module_name, new_module) self._register_packed_modules(module_name) # All lora layers share the same punica_wrapper based on reference. @@ -478,9 +501,10 @@ def create_dummy_lora( """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}, scaling_factor) for module_name, module in self.model.named_modules(): - if not self._match_target_modules(module_name) or not isinstance( - module, BaseLayerWithLoRA) or isinstance( - module, LinearScalingRotaryEmbeddingWithLora): + if (not self._match_target_modules(module_name) + or not isinstance(module, BaseLayerWithLoRA) + or isinstance(module, LinearScalingRotaryEmbeddingWithLora) + or self._filter_unsupported_mm_module(module_name)): continue parts = module_name.split(".") if module_name not in self.packed_modules: @@ -541,6 +565,19 @@ def _match_target_modules(self, module_name: str): module_name) or target_module == module_name for target_module in self.supported_lora_modules) + def _filter_unsupported_mm_module(self, module_name: str) -> bool: + """ + Regarding multimodal models, vLLM currently only supports adding LoRA to + language model. LoRA for other modules, such as the vision tower, will + be filtered out. + """ + if self.supports_mm: + prefix = module_name.split(".")[0] + module_mapping: MultiModelKeys = self.model.get_mm_mapping() + return (prefix in module_mapping.connector + or prefix in module_mapping.tower_model) + return False + def _register_packed_modules(self, module_full_name: str) -> None: parts = module_full_name.split(".") module_name = parts[-1] diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 7da7991b4f849..89cdfbcc6afa9 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -36,7 +36,7 @@ from typing_extensions import NotRequired from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig +from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -50,7 +50,9 @@ from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.minicpm import MiniCPMModel +from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.model_executor.models.utils import LLMWrapper from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -59,10 +61,10 @@ from vllm.sequence import IntermediateTensors, SequenceData from .idefics2_vision_model import Idefics2VisionTransformer +from .interfaces import SupportsLoRA _KEYS_TO_MODIFY_MAPPING = { "llm.lm_head": "lm_head", - "llm.model": "llm", } @@ -621,6 +623,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): default_weight_loader) weight_loader(param, loaded_weight) + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field(language_model="llm", + connector="resampler", + tower_model="vpm") + def init_llm( self, config: PretrainedConfig, @@ -669,9 +679,11 @@ def init_llm( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> nn.Module: - return MiniCPMModel(config, - cache_config=cache_config, - quant_config=quant_config) + + return LLMWrapper(MiniCPMModel(config, + cache_config=cache_config, + quant_config=quant_config), + name="model") def init_vision_module(self) -> nn.Module: # TODO :refactor this vision model @@ -697,6 +709,9 @@ def init_vision_module(self) -> nn.Module: return model + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_tokens(input_ids) + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: with set_default_torch_dtype(torch.float16): resampler = Resampler2( @@ -743,7 +758,34 @@ def is_default_weight_loading(self, name: str) -> bool: return "resampler" in name or "vpm" in name -class MiniCPMV2_5(MiniCPMVBaseModel): +class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + # vision encoder + "fc1", + "fc2", + "out_proj", + # language model + "qkv_proj", # same name with vision encoder + "o_proj", + "gate_up_proj", + "down_proj", + # resampler + "kv_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] def __init__( self, @@ -751,6 +793,7 @@ def __init__( multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, ): super().__init__(config, multimodal_config, cache_config, quant_config) assert self.version == (2, 5) @@ -761,9 +804,10 @@ def init_llm( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> nn.Module: - return LlamaModel(config, - cache_config=cache_config, - quant_config=quant_config) + return LLMWrapper(LlamaModel(config, + cache_config=cache_config, + quant_config=quant_config), + name="model") def init_vision_module(self) -> nn.Module: model = Idefics2VisionTransformer(self.config.vision_config) @@ -843,9 +887,11 @@ def init_llm( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> nn.Module: - return Qwen2Model(config, - cache_config=cache_config, - quant_config=quant_config) + + return LLMWrapper(Qwen2Model(config, + cache_config=cache_config, + quant_config=quant_config), + name="model") def init_vision_module(self) -> nn.Module: # A custom version of SiglipVisionTransformer, won't work with TP @@ -870,7 +916,6 @@ def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: num_heads=embed_dim // 128, kv_dim=vision_dim, ) - return resampler def get_vision_embedding( @@ -934,20 +979,25 @@ def is_default_weight_loading(self, name: str) -> bool: @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) -class MiniCPMV(MiniCPMVBaseModel): +class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): """ Different versions of MiniCPMV use different visual encoders and LLMs, which is not conducive to the current integration logic of LoRA and bitsandbytes in vLLM. Therefore, it is necessary to separate them. """ - - def __new__( - cls, - config: PretrainedConfig, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + # Ensure that the LoRA support check passes when the class is not + # initialized, but set all these attributes to empty. + packed_modules_mapping = {} + supported_lora_modules = [] + embedding_modules = {} + embedding_padding_modules = [] + + def __new__(cls, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): if not hasattr(config, "version"): if config.hidden_size == 2304 and config.query_num == 64: version = (2, 0) diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py new file mode 100644 index 0000000000000..a9102a6073a2f --- /dev/null +++ b/vllm/model_executor/models/module_mapping.py @@ -0,0 +1,69 @@ +# Adapted from +# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py + +from dataclasses import dataclass, field +from typing import List, Union + + +@dataclass +class ModelKeys: + model_type: str = None + + module_list: str = None + + embedding: str = None + + mlp: str = None + + down_proj: str = None + + attention: str = None + + o_proj: str = None + + q_proj: str = None + + k_proj: str = None + + v_proj: str = None + + qkv_proj: str = None + + qk_proj: str = None + + qa_proj: str = None + + qb_proj: str = None + + kva_proj: str = None + + kvb_proj: str = None + + output: str = None + + +@dataclass +class MultiModelKeys(ModelKeys): + language_model: List[str] = field(default_factory=list) + connector: List[str] = field(default_factory=list) + # vision tower and audio tower + tower_model: List[str] = field(default_factory=list) + generator: List[str] = field(default_factory=list) + + @staticmethod + def from_string_field(language_model: Union[str, List[str]] = None, + connector: Union[str, List[str]] = None, + tower_model: Union[str, List[str]] = None, + generator: Union[str, List[str]] = None, + **kwargs) -> 'MultiModelKeys': + + def to_list(value): + if value is None: + return [] + return [value] if isinstance(value, str) else list(value) + + return MultiModelKeys(language_model=to_list(language_model), + connector=to_list(connector), + tower_model=to_list(tower_model), + generator=to_list(generator), + **kwargs) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 38d6a4653ebd6..f6218bad4ef1e 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from collections import UserDict -from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, - Union, overload) +from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol, + Tuple, Union, overload) import torch import torch.nn as nn @@ -329,3 +329,21 @@ def make_empty_intermediate_tensors( }) return make_empty_intermediate_tensors + + +class LLMWrapper(nn.Module): + """ + To align with the key names of LoRA trained with PEFT, we need to add an + additional layer to the llm's implementation. + """ + + def __init__(self, llm: nn.Module, name: str) -> None: + super().__init__() + self.model_name = name + setattr(self, name, llm) + + def forward(self, *args, **kwargs) -> Any: + return getattr(self, self.model_name)(*args, **kwargs) + + def embed_tokens(self, *args, **kwargs) -> Any: + return getattr(self, self.model_name).embed_tokens(*args, **kwargs) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4ac67a5fade8f..6e5c4826da3d3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1034,10 +1034,12 @@ def load_model(self) -> None: self.model_memory_usage / float(2**30)) if self.lora_config: - assert supports_lora(self.model), "Model does not support LoRA" - assert not supports_multimodal( + assert supports_lora( self.model - ), "To be tested: Multi-modal model with LoRA settings." + ), f"{self.model.__class__.__name__} does not support LoRA yet." + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs,