Skip to content

Commit

Permalink
[Model][LoRA]LoRA support added for MiniCPMV2.5 (vllm-project#7199)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeejeelee authored Sep 29, 2024
1 parent bc2ef1f commit 3d49776
Show file tree
Hide file tree
Showing 8 changed files with 378 additions and 31 deletions.
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def baichuan_zero_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


@pytest.fixture(scope="session")
def minicpmv_lora_files():
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
Expand Down
71 changes: 71 additions & 0 deletions tests/lora/test_minicpmv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import List

import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n")

IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]

# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=5,
stop_token_ids=[128001, 128009], # eos_id, eot_id
)

inputs = [{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in IMAGE_ASSETS]

outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_num_seqs=2,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True,
)

output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output1[i])
output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output2[i])
95 changes: 95 additions & 0 deletions tests/lora/test_minicpmv_tp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import List

import pytest

import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n")

IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]

# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=5,
stop_token_ids=[128001, 128009], # eos_id, eot_id
)

inputs = [{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in IMAGE_ASSETS]

outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=2,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
)

output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)

for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])


@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=2,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
45 changes: 41 additions & 4 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from vllm.lora.punica import PunicaWrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.model_executor.models.interfaces import (SupportsLoRA,
supports_multimodal)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.utils import is_pin_memory_available

Expand Down Expand Up @@ -332,6 +334,8 @@ def __init__(
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = supports_multimodal(self.model)
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
# Dict instead of a Set for compatibility with LRUCache.
Expand Down Expand Up @@ -437,12 +441,22 @@ def _create_lora_modules(self):
continue
if not self._match_target_modules(module_name):
continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if self._filter_unsupported_mm_module(module_name):
logger.warning(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored.",
module_name,
)
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule(
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config))

# LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
Expand All @@ -460,6 +474,15 @@ def _create_lora_modules(self):
module, self.lora_slots,
self.lora_config,
self.model.config))

# In some models, especially multimodal ones, layers with the same
# name may have different types, such as nn.Linear and
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
# LoRA layers, leading to assertion error. The following check
# aims to prevent this error
if self.supports_mm and not isinstance(new_module,
BaseLayerWithLoRA):
continue
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
Expand All @@ -478,9 +501,10 @@ def create_dummy_lora(
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA) or isinstance(
module, LinearScalingRotaryEmbeddingWithLora):
if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
or self._filter_unsupported_mm_module(module_name)):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
Expand Down Expand Up @@ -541,6 +565,19 @@ def _match_target_modules(self, module_name: str):
module_name) or target_module == module_name
for target_module in self.supported_lora_modules)

def _filter_unsupported_mm_module(self, module_name: str) -> bool:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if self.supports_mm:
prefix = module_name.split(".")[0]
module_mapping: MultiModelKeys = self.model.get_mm_mapping()
return (prefix in module_mapping.connector
or prefix in module_mapping.tower_model)
return False

def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
Expand Down
Loading

0 comments on commit 3d49776

Please sign in to comment.