From a115ac46b5be22289dec975c2c06653b22cd6315 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 1 Jan 2025 23:44:42 +0800 Subject: [PATCH] [VLM] Move supported limits and max tokens to merged multi-modal processor (#11669) Signed-off-by: DarkLight1337 Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py <2037008807@qq.com> --- .../mm_processor_kwargs/test_phi3v.py | 39 +----- .../mm_processor_kwargs/test_qwen2_vl.py | 36 +----- tests/multimodal/test_processing.py | 14 ++- vllm/inputs/registry.py | 8 +- vllm/model_executor/models/aria.py | 75 ++++++------ vllm/model_executor/models/blip2.py | 19 ++- vllm/model_executor/models/chameleon.py | 35 +++--- vllm/model_executor/models/fuyu.py | 105 ++++++++--------- vllm/model_executor/models/llava.py | 8 +- vllm/model_executor/models/phi3v.py | 45 +++---- vllm/model_executor/models/qwen2_audio.py | 42 +++++-- vllm/model_executor/models/qwen2_vl.py | 75 ++++++------ vllm/model_executor/models/ultravox.py | 26 ++-- vllm/multimodal/parse.py | 47 ++------ vllm/multimodal/processing.py | 111 ++++++++++++++++-- vllm/multimodal/registry.py | 5 + 16 files changed, 340 insertions(+), 350 deletions(-) diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py index f95cee277f4e6..3edf96d11106d 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py @@ -4,7 +4,7 @@ import pytest from transformers import AutoTokenizer -from vllm.inputs import InputContext, InputProcessingContext +from vllm.inputs import InputProcessingContext from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID from .....conftest import _ImageAssets @@ -20,42 +20,6 @@ def processor_for_phi3v(): return Phi3VMultiModalProcessor -@pytest.fixture() -def get_max_phi3v_image_tokens(): - from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens - return get_max_phi3v_image_tokens - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("num_crops,expected_max_tokens", [ - (4, 781), - (16, 2653), -]) -def test_max_tokens_override(get_max_phi3v_image_tokens, model: str, - num_crops: int, expected_max_tokens: int): - """Ensure get_max_phi3v_image_tokens handles num_crops properly.""" - # NOTE: mm_processor_kwargs on the context in this test is unused, since - # this is testing the mapper directly. In practice, the processor kwargs - # are wrapped in a closure when calling the max tokens func. We explicitly - # do NOT use the mm_processor_kwargs in the model context here to ensure - # that the max image tokens implementation is referencing a mix of the - # kwargs to the function and the original mm_processor_kwargs in case - # values are somehow updated and end up in a bad state. 
- ctx = build_model_context( - model_name=model, - tokenizer_name=model, - trust_remote_code=True, - mm_processor_kwargs=None, - ) - - actual_max_tokens = get_max_phi3v_image_tokens( - InputContext(ctx.model_config), - num_crops=num_crops, - ) - - assert expected_max_tokens == actual_max_tokens - - @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "num_crops,expected_toks_per_img", @@ -77,6 +41,7 @@ def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets, model_name=model, tokenizer_name=model, trust_remote_code=True, + limit_mm_per_prompt={"image": num_imgs}, ) tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) ctx = InputProcessingContext(ctx.model_config, tokenizer) diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py index 5897c04c89e19..1f0b482666723 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py @@ -3,7 +3,7 @@ import pytest from transformers import AutoTokenizer -from vllm.inputs import InputContext, InputProcessingContext +from vllm.inputs import InputProcessingContext from .....conftest import _ImageAssets from ....utils import build_model_context @@ -22,39 +22,6 @@ def processor_for_qwen2_vl(): return Qwen2VLMultiModalProcessor -@pytest.fixture() -def get_max_qwen2_vl_image_tokens(): - from vllm.model_executor.models.qwen2_vl import ( - get_max_qwen2_vl_image_tokens) - return get_max_qwen2_vl_image_tokens - - -@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [ - ({}, 16384), - ({ - MIN_PIXELS: 64**2, - MAX_PIXELS: 512**2 - }, 324), -]) -@pytest.mark.parametrize("model", [MODEL]) -def test_qwen2_vl_max_image_tokens( - get_max_qwen2_vl_image_tokens, - model: str, - mm_processor_kwargs: Dict[str, Any], - expected_max_tokens: int, -): - """Ensure that the max token calc handles min/max pixels properly.""" - ctx = build_model_context( - model_name=model, - tokenizer_name=model, - mm_processor_kwargs=None, - ) - - actual_max_tokens = get_max_qwen2_vl_image_tokens( - InputContext(ctx.model_config), **mm_processor_kwargs) - assert actual_max_tokens == expected_max_tokens - - @pytest.mark.parametrize( "mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [ ({}, 1426, (5704, 1176)), @@ -82,6 +49,7 @@ def test_processor_override( model_name=model, tokenizer_name=model, mm_processor_kwargs=None, + limit_mm_per_prompt={"image": num_imgs}, ) tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) ctx = InputProcessingContext(ctx.model_config, tokenizer) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 1850ca46ccc8f..9573351b4dff1 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -538,6 +538,11 @@ def _test_processing_cache_correctness( else: hf_overrides = {} + limit_mm_per_prompt = { + modality: 3 if supports_multi else 1 + for modality, supports_multi in modalities.items() + } + model_config = ModelConfig( model_id, task="auto", @@ -548,6 +553,7 @@ def _test_processing_cache_correctness( dtype="float16", revision=None, hf_overrides=hf_overrides, + limit_mm_per_prompt=limit_mm_per_prompt, ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) @@ -580,18 +586,14 @@ def _test_processing_cache_correctness( min_wh=128, max_wh=256), "audio": - 
partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000), - } - input_max_count = { - modality: 3 if supports_multi else 1 - for modality, supports_multi in modalities.items() + partial(_rand_audio, rng, min_len=512, max_len=1024, sr=16000), } for batch_idx in range(num_batches): mm_data = { k: [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) - for _ in range(rng.randint(input_max_count[k]))] + for _ in range(rng.randint(limit_mm_per_prompt[k]))] for k in modalities } diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 46346b08e99c2..090347706ca93 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -331,13 +331,7 @@ def dummy_data_for_profiling( trust_remote_code=model_config.trust_remote_code, ) processor = mm_registry.create_processor(model_config, tokenizer) - - mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - mm_max_tokens = mm_registry.get_max_tokens_by_modality( - model_config) - - dummy_data = processor.get_dummy_data(seq_len, mm_counts, - mm_max_tokens) + dummy_data = processor.get_dummy_data(seq_len) else: model_cls, _ = get_model_architecture(model_config) if is_encoder_data: diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 4ad6e859f4d93..4f0d679bd6c28 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -1,5 +1,5 @@ -from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, - Union) +from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -9,7 +9,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, QuantizationConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_rank -from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -87,8 +86,8 @@ def __init__( def forward( self, pixel_values: torch.Tensor, - pixel_mask: Optional[torch.BoolTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]: + pixel_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: patch_attention_mask = self._create_patch_attention_mask(pixel_mask) vit_oup = self.vision_model( @@ -100,7 +99,8 @@ def forward( return vit_oup, image_atts - def _create_patch_attention_mask(self, pixel_mask): + def _create_patch_attention_mask( + self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor: if pixel_mask is None: return None @@ -115,7 +115,8 @@ def _create_patch_attention_mask(self, pixel_mask): ) return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() - def _create_image_attention_mask(self, patch_attention_mask): + def _create_image_attention_mask( + self, patch_attention_mask: torch.Tensor) -> torch.Tensor: if patch_attention_mask is None: return None @@ -125,13 +126,13 @@ def _create_image_attention_mask(self, patch_attention_mask): class FFN(nn.Module): - def __init__(self, embed_dim, ff_dim, output_dim): + def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None: super().__init__() self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False) self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False) self.act = get_act_fn("gelu_new") - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states, _ = 
self.linear_in(hidden_states) hidden_states = self.act(hidden_states) hidden_states, _ = self.linear_out(hidden_states) @@ -140,7 +141,7 @@ def forward(self, hidden_states): class CrossAttention(nn.Module): - def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0): + def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None: super().__init__() self.num_heads = num_heads self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) @@ -149,12 +150,16 @@ def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0): self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) self.linear = nn.Linear(embed_dim, embed_dim) - self.dropout = nn.Dropout(drop_out_rate) self.layer_norm = nn.LayerNorm(embed_dim) self.ln_kv = nn.LayerNorm(kv_dim) - def forward(self, x, hidden_states, attn_mask=None, add_residual=False): + def forward( + self, + x: torch.Tensor, + hidden_states: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: normed_hidden_states = self.layer_norm(hidden_states) query = self.q_proj(normed_hidden_states).permute(1, 0, 2) @@ -169,11 +174,7 @@ def forward(self, x, hidden_states, attn_mask=None, add_residual=False): attn_output = attn_output.permute(1, 0, 2) - if add_residual: - attn_output = hidden_states + self.dropout( - self.linear(attn_output)) - else: - attn_output = self.dropout(self.linear(attn_output)) + attn_output = self.linear(attn_output) return attn_output @@ -201,14 +202,14 @@ class AriaProjector(nn.Module): def __init__( self, - patch_to_query_dict, - embed_dim, - num_heads, - kv_dim, - ff_dim, - output_dim, - norm_layer=nn.LayerNorm, - ): + patch_to_query_dict: dict[int, int], + embed_dim: int, + num_heads: int, + kv_dim: int, + ff_dim: int, + output_dim: int, + norm_layer: Callable[[int], nn.Module] = nn.LayerNorm, + ) -> None: super().__init__() self.patch_to_query_dict = patch_to_query_dict self.embed_dim = embed_dim @@ -224,7 +225,11 @@ def __init__( self.ln_ffn = norm_layer(embed_dim) self.ffn = FFN(embed_dim, ff_dim, output_dim) - def forward(self, x, attn_mask=None): + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: bs = x.shape[0] queries = self.query.unsqueeze(0).repeat(bs, 1, 1) @@ -442,12 +447,17 @@ def build_mm_projector(config: PretrainedConfig): ) -def get_max_aria_image_tokens(ctx: InputContext): - hf_config = ctx.get_hf_config() - return max(hf_config.projector_patch_to_query_dict.values()) +class AriaMultiModalProcessor(BaseMultiModalProcessor): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + def _get_num_image_tokens(self) -> int: + hf_config = self.ctx.get_hf_config() + return max(hf_config.projector_patch_to_query_dict.values()) -class AriaMultiModalProcessor(BaseMultiModalProcessor): + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + return {"image": self._get_num_image_tokens()} def _get_mm_fields_config( self, @@ -468,13 +478,13 @@ def _get_prompt_replacements( hf_config = self.ctx.get_hf_config() image_token_id = hf_config.image_token_index - max_image_tokens = get_max_aria_image_tokens(self.ctx) + num_image_tokens = self._get_num_image_tokens() return [ PromptReplacement( modality="image", target=[image_token_id], - replacement=[image_token_id] * max_image_tokens, + replacement=[image_token_id] * num_image_tokens, ) ] @@ -504,7 +514,6 @@ def _get_dummy_mm_inputs( ) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens) 
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor) class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): """ diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 50680fadc4aa3..0fe10d8585215 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -9,7 +9,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VllmConfig -from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -18,7 +17,6 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) -from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, MultiModalDataItems, ProcessorInputs, PromptReplacement) @@ -398,15 +396,17 @@ def forward( return sequence_output -def get_max_blip2_image_tokens(ctx: InputContext): - hf_config = ctx.get_hf_config(Blip2Config) - return hf_config.num_query_tokens +class Blip2MultiModalProcessor(BaseMultiModalProcessor): + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} -class Blip2MultiModalProcessor(BaseMultiModalProcessor): + def _get_num_image_tokens(self) -> int: + hf_config = self.ctx.get_hf_config(Blip2Config) + return hf_config.num_query_tokens - def _get_data_parser(self) -> MultiModalDataParser: - return MultiModalDataParser(max_mm_counts={"image": 1}) + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + return {"image": self._get_num_image_tokens()} def _get_hf_processor(self) -> Blip2Processor: return self.ctx.get_hf_processor(Blip2Processor) @@ -427,7 +427,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - max_image_tokens = get_max_blip2_image_tokens(self.ctx) + max_image_tokens = self._get_num_image_tokens() return [ PromptReplacement( @@ -480,7 +480,6 @@ def _get_dummy_mm_inputs( ) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor) class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index c731934e792fc..0bd0194243ceb 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -11,7 +11,6 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.inputs import InputContext from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -31,7 +30,6 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) -from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, MultiModalDataItems, ProcessorInputs, PromptReplacement) @@ -43,11 +41,6 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, 
merge_multimodal_embeddings) -# These configs are not part of the model config but the preprocessor -# and processor files, so we hardcode them in the model file for now. -CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512 -CHAMELEON_IMAGE_SEQ_LENGTH = 1024 - class ChameleonImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -55,14 +48,17 @@ class ChameleonImagePixelInputs(TypedDict): """Shape: `(batch_size * num_images, num_channels, height, width)`""" -def get_max_chameleon_image_tokens(ctx: InputContext): - return CHAMELEON_IMAGE_SEQ_LENGTH +class ChameleonMultiModalProcessor(BaseMultiModalProcessor): + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} -class ChameleonMultiModalProcessor(BaseMultiModalProcessor): + def _get_num_image_tokens(self) -> int: + processor = self._get_hf_processor() + return processor.image_seq_length - def _get_data_parser(self) -> MultiModalDataParser: - return MultiModalDataParser(max_mm_counts={"image": 1}) + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + return {"image": self._get_num_image_tokens()} def _get_hf_processor(self) -> ChameleonProcessor: return self.ctx.get_hf_processor(ChameleonProcessor) @@ -88,7 +84,7 @@ def _get_prompt_replacements( target="", replacement="".join([ processor.image_start_token, - processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH, + processor.image_token * self._get_num_image_tokens(), processor.image_end_token, ]), ) @@ -98,12 +94,15 @@ def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], ) -> ProcessorInputs: + config = self.ctx.get_hf_config(ChameleonConfig) + + width = height = config.vq_config.resolution num_images = mm_counts.get("image", 0) mm_data = { "image": - self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH, - height=CHAMELEON_CROP_SIZE_HEIGHT, + self._get_dummy_images(width=width, + height=height, num_images=num_images) } @@ -902,7 +901,6 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) @MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor) class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): @@ -931,9 +929,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model.make_empty_intermediate_tensors) def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT, - CHAMELEON_CROP_SIZE_WIDTH) + vq_config: ChameleonVQVAEConfig = self.config.vq_config + expected_dims = (3, vq_config.resolution, vq_config.resolution) actual_dims = tuple(data.shape[1:]) if actual_dims != expected_dims: diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 0a48fa3fe11c0..7fb8c5d1ab09c 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -25,7 +25,6 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import InputContext from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.models.persimmon import PersimmonForCausalLM @@ -34,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) -from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser +from vllm.multimodal.parse import ImageProcessorItems, ImageSize from 
vllm.multimodal.processing import (BaseMultiModalProcessor, MultiModalDataItems, ProcessorInputs, PromptReplacement) @@ -48,9 +47,6 @@ _IMAGE_TOKEN_ID = 71011 _NEWLINE_TOKEN_ID = 71019 -MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080 -MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920 - class FuyuImagePatchInputs(TypedDict): type: Literal["image_patches"] @@ -67,43 +63,49 @@ class FuyuImagePatchInputs(TypedDict): """ -def _get_fuyu_num_image_tokens( - image_height: int, - image_width: int, -) -> Tuple[int, int]: - """ - Calculate the number of image tokens needed for a given image size. - - The expected Fuyu image prompts can be expressed as: - - .. code-block:: - (image_token * ncols + newline_token) * nrows - - Args: - image_size: Tuple[int, int] - `(width, height)` of the image - - Returns: - ncols: int - number of image tokens in `x` direction - nrows: int - number of image tokens in `y` direction - """ - ncols = math.ceil(image_width / 30) - nrows = math.ceil(image_height / 30) - return ncols, nrows - +class FuyuMultiModalProcessor(BaseMultiModalProcessor): -def get_max_fuyu_image_tokens(ctx: InputContext): - ncols, nrows = _get_fuyu_num_image_tokens( - image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, - image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, - ) + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} - return (ncols + 1) * nrows + def _get_image_target_size(self) -> ImageSize: + processor = self._get_hf_processor() + image_processor: FuyuImageProcessor = processor.image_processor + target_size = image_processor.size + return ImageSize(width=target_size["width"], + height=target_size["height"]) -class FuyuMultiModalProcessor(BaseMultiModalProcessor): + def _get_image_grid_size( + self, + *, + image_width: int, + image_height: int, + ) -> tuple[int, int]: + target_width, target_height = self._get_image_target_size() + + if not (image_width <= target_width and image_height <= target_height): + height_scale_factor = target_height / image_height + width_scale_factor = target_width / image_width + optimal_scale_factor = min(height_scale_factor, width_scale_factor) + + image_height = int(image_height * optimal_scale_factor) + image_width = int(image_width * optimal_scale_factor) + + ncols = math.ceil(image_width / 30) + nrows = math.ceil(image_height / 30) + return ncols, nrows + + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + target_width, target_height = self._get_image_target_size() + + max_ncols, max_nrows = self._get_image_grid_size( + image_width=target_width, + image_height=target_height, + ) + max_image_tokens = (max_ncols + 1) * max_nrows - def _get_data_parser(self) -> MultiModalDataParser: - return MultiModalDataParser(max_mm_counts={"image": 1}) + return {"image": max_image_tokens} def _get_hf_processor(self) -> FuyuProcessor: return self.ctx.get_hf_processor(FuyuProcessor) @@ -166,28 +168,13 @@ def _get_prompt_replacements( eot_token_id = tokenizer.bos_token_id assert isinstance(eot_token_id, int) - hf_processor = self._get_hf_processor() - image_processor: FuyuImageProcessor = hf_processor.image_processor - target_size = image_processor.size - target_height, target_width = (target_size["height"], - target_size["width"]) - def get_replacement_fuyu(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) - width, height = image_size.width, image_size.height - if not (width <= target_width and height <= target_height): - height_scale_factor = target_height / height - width_scale_factor = 
target_width / width - optimal_scale_factor = min(height_scale_factor, - width_scale_factor) - - height = int(height * optimal_scale_factor) - width = int(width * optimal_scale_factor) - - ncols, nrows = _get_fuyu_num_image_tokens( - image_width=width, - image_height=height, + + ncols, nrows = self._get_image_grid_size( + image_width=image_size.width, + image_height=image_size.height, ) return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows + @@ -225,12 +212,13 @@ def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], ) -> ProcessorInputs: + target_width, target_height = self._get_image_target_size() num_images = mm_counts.get("image", 0) mm_data = { "image": - self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH, - height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + self._get_dummy_images(width=target_width, + height=target_height, num_images=num_images) } @@ -240,7 +228,6 @@ def _get_dummy_mm_inputs( ) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens) @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 34dc7fa31ce6f..808e61edb6fb4 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -119,6 +119,12 @@ def get_max_llava_image_tokens(ctx: InputContext): class LlavaMultiModalProcessor(BaseMultiModalProcessor): + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + return {"image": get_max_llava_image_tokens(self.ctx)} + def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]: return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor)) @@ -324,7 +330,6 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): # BitandBytes specific attributes @@ -649,7 +654,6 @@ def get_replacement_mantis(item_idx: int): # To use this model, please use # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor) class MantisForConditionalGeneration(LlavaForConditionalGeneration): pass diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 15362db6cdfbf..d855e7d2d36f8 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -23,7 +23,6 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import InputContext from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -306,24 +305,31 @@ def add_image_newline(self, image_features_hd): return image_features_hd_newline -def get_max_phi3v_image_tokens( - ctx: InputContext, - *, - num_crops: Optional[int] = None, -) -> int: - hf_processor_mm_kwargs = {} - if num_crops: - hf_processor_mm_kwargs["num_crops"] = num_crops +class Phi3VMultiModalProcessor(BaseMultiModalProcessor): - processor = ctx.get_hf_processor(**hf_processor_mm_kwargs) + def get_supported_mm_limits(self) -> 
Mapping[str, Optional[int]]: + return {"image": None} - return processor.calc_num_image_tokens_from_image_size( - width=MAX_IMAGE_FEATURE_SIZE_WIDTH, - height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, - ) + def _get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + processor = self._get_hf_processor() + + return processor.calc_num_image_tokens_from_image_size( # type: ignore + width=image_width, + height=image_height, + ) + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + max_image_tokens = self._get_num_image_tokens( + image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + ) -class Phi3VMultiModalProcessor(BaseMultiModalProcessor): + return {"image": max_image_tokens} def _get_hf_processor( self, @@ -332,6 +338,7 @@ def _get_hf_processor( ) -> ProcessorMixin: if num_crops is not None: return self.ctx.get_hf_processor(num_crops=num_crops) + return self.ctx.get_hf_processor() def _call_hf_processor( @@ -375,7 +382,6 @@ def _get_prompt_replacements( ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() image_tokens: list[str] = hf_processor.img_tokens # type: ignore - image_processor = hf_processor.image_processor # type: ignore tokenizer = self._get_tokenizer() bos_token_id = tokenizer.bos_token_id @@ -385,9 +391,9 @@ def get_replacement_phi3v(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) - num_tokens = image_processor.calc_num_image_tokens_from_image_size( - width=image_size.width, - height=image_size.height, + num_tokens = self._get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, ) return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id] @@ -467,7 +473,6 @@ def apply( return result -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor) class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index de55bc6bcc123..d050fd060353a 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -33,13 +33,12 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import InputContext from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) -from vllm.multimodal.parse import MultiModalDataParser +from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, MultiModalDataItems, ProcessorInputs, PromptReplacement) @@ -80,14 +79,17 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): return feat_lengths, output_lengths -def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int: - hf_config = ctx.get_hf_config(Qwen2AudioConfig) - max_source_position = hf_config.audio_config.max_source_positions - output_lengths = (max_source_position - 2) // 2 + 1 - return output_lengths +class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None} -class 
Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) + max_source_positions = hf_config.audio_config.max_source_positions + max_output_lengths = (max_source_positions - 2) // 2 + 1 + + return {"audio": max_output_lengths} def _get_hf_processor( self, @@ -157,11 +159,21 @@ def _get_prompt_replacements( audio_output_lengths = [] else: assert isinstance(feature_attention_mask, torch.Tensor) - _, audio_output_lengths = _get_feat_extract_output_lengths( + _, audio_output_lens = _get_feat_extract_output_lengths( feature_attention_mask.sum(-1)) + audio_output_lengths = audio_output_lens.tolist() + def get_replacement_qwen2_audio(item_idx: int): - return [placeholder] * audio_output_lengths[item_idx] + num_placeholders = audio_output_lengths[item_idx] + if num_placeholders == 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio = audios.get(item_idx) + raise ValueError( + f"The audio {audio} (len={len(audio)}) is too short " + "to be represented inside the model") + + return [placeholder] * num_placeholders return [ PromptReplacement( @@ -171,6 +183,14 @@ def get_replacement_qwen2_audio(item_idx: int): ) ] + def _always_apply_prompt_replacements(self) -> bool: + # HF never applies prompt replacements, so we have to do it ourselves + # _find_placeholders may incorrectly think that HF has already performed + # processing for multi-audio input when the input audios are short + # (the corresponding placeholders may take up fewer tokens than + # the number of audio items) + return True + def _get_dummy_mm_inputs( self, mm_counts: Mapping[str, int], @@ -192,8 +212,6 @@ def _get_dummy_mm_inputs( ) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_max_qwen2_audio_audio_tokens) @MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor) class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 0df101b3dcce4..26b6d768ad4f6 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -40,7 +40,6 @@ from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils -from vllm.inputs import InputContext from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU @@ -650,8 +649,9 @@ def _get_vision_info( width: int, min_pixels: int, max_pixels: int, + *, do_resize: bool = True, - data_type_key: str = "image", + modality: str = "image", mm_count: int = 1, ): """Get information (resized height / width and number of vision tokens) @@ -671,11 +671,12 @@ def _get_vision_info( else: resized_height, resized_width = height, width - if data_type_key == "image": + if modality == "image": grid_t = mm_count - else: - assert data_type_key == "video" + elif modality == "video": grid_t = max(mm_count // temporal_patch_size, 1) + else: + raise ValueError(f"Modality {modality} is not supported") grid_h = resized_height // patch_size grid_w = resized_width // patch_size @@ -691,41 +692,11 @@ def _get_image_processor(hf_processor: Qwen2VLProcessor): return image_processor -def get_max_qwen2_vl_mm_tokens(ctx: InputContext, - data_type_key: str, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None) -> int: - hf_config = 
ctx.get_hf_config(Qwen2VLConfig) - vision_config = hf_config.vision_config - - hf_processor = ctx.get_hf_processor(Qwen2VLProcessor) - image_processor = _get_image_processor(hf_processor) - - _, _, max_llm_image_tokens = _get_vision_info( - vision_config, - height=9999999, - width=9999999, - min_pixels=min_pixels or image_processor.min_pixels, - max_pixels=max_pixels or image_processor.max_pixels, - data_type_key=data_type_key, - ) - return max_llm_image_tokens - - -get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens, - data_type_key="image") -get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens, - data_type_key="video") - - class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], dict[str, torch.Tensor]]): def __init__(self, data: dict, modality: str) -> None: - super().__init__(data) - - self.modality = modality + super().__init__(data, modality) grid_thw = data[f"{modality}_grid_thw"] slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist() @@ -734,9 +705,6 @@ def __init__(self, data: dict, modality: str) -> None: for i in range(len(grid_thw)) ] - def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r})") - def get_count(self) -> int: return len(self.data[f"{self.modality}_grid_thw"]) @@ -792,6 +760,32 @@ def _parse_video_data( class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def _get_max_mm_tokens(self, modality: str) -> int: + hf_config = self.ctx.get_hf_config(Qwen2VLConfig) + vision_config = hf_config.vision_config + + hf_processor = self._get_hf_processor() + image_processor = _get_image_processor(hf_processor) + + _, _, max_llm_image_tokens = _get_vision_info( + vision_config, + height=9999999, + width=9999999, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + modality=modality, + ) + return max_llm_image_tokens + + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + return { + "image": self._get_max_mm_tokens("image"), + "video": self._get_max_mm_tokens("video"), + } + def _get_data_parser(self) -> MultiModalDataParser: return Qwen2MultiModalDataParser() @@ -908,9 +902,6 @@ def _get_dummy_mm_inputs( ) -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "video", get_max_qwen2_vl_video_tokens) @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor) class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 54be7fed3f2be..0b83684c9bac5 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -2,7 +2,7 @@ """PyTorch Ultravox model.""" import math -from functools import cached_property, lru_cache +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -17,7 +17,6 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import InputContext from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler @@ -58,22 +57,17 @@ class UltravoxAudioEmbeddingInputs(TypedDict): UltravoxAudioEmbeddingInputs] -@lru_cache -def 
cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor: - return WhisperFeatureExtractor.from_pretrained(model_id) - - -def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor: - hf_config = ctx.get_hf_config(UltravoxConfig) - return cached_feature_extractor(hf_config.audio_model_id) - +class UltravoxMultiModalProcessor(BaseMultiModalProcessor): -def get_ultravox_max_audio_tokens(ctx: InputContext): - feature_extractor = whisper_feature_extractor(ctx) - return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": None} + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + feature_extractor = self._get_feature_extractor() + max_audio_tokens = math.ceil(feature_extractor.chunk_length * + _AUDIO_TOKENS_PER_SECOND) -class UltravoxMultiModalProcessor(BaseMultiModalProcessor): + return {"audio": max_audio_tokens} def _get_hf_processor( self, @@ -322,8 +316,6 @@ def forward( return hidden_states -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", get_ultravox_max_audio_tokens) @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor) class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index da111e999ebb8..4e1b78ab2c59d 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -21,10 +21,15 @@ class ModalityDataItems(ABC, Generic[_T, _I]): - def __init__(self, data: _T) -> None: + def __init__(self, data: _T, modality: str) -> None: super().__init__() self.data = data + self.modality = modality + + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r}, " + f"len={len(self)})") def __len__(self) -> int: return self.get_count() @@ -64,14 +69,6 @@ def get_passthrough_data(self) -> Mapping[str, object]: class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]): - def __init__(self, data: Sequence[_T], modality: str) -> None: - super().__init__(data) - - self.modality = modality - - def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r})") - def get_count(self) -> int: return len(self.data) @@ -87,14 +84,6 @@ def get_passthrough_data(self) -> Mapping[str, object]: class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]): - def __init__(self, data: NestedTensors, modality: str) -> None: - super().__init__(data) - - self.modality = modality - - def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r})") - def get_count(self) -> int: return len(self.data) @@ -222,22 +211,13 @@ class MultiModalDataParser: Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`. Args: - max_mm_counts (Mapping[str, int]): The maximum allowed number of items - belonging to each modality. This effectively sets a hard limit over - `--limit-mm-per-prompt`. target_sr (float, optional): Enables automatic resampling of audio items to the model's expected sampling rate. 
""" - def __init__( - self, - *, - max_mm_counts: Mapping[str, int] = {}, - target_sr: Optional[float] = None, - ) -> None: + def __init__(self, *, target_sr: Optional[float] = None) -> None: super().__init__() - self.max_mm_counts = max_mm_counts self.target_sr = target_sr def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]: @@ -345,7 +325,6 @@ def _get_subparsers(self) -> Mapping[str, ModalityDataParser]: def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems: - max_mm_counts = self.max_mm_counts subparsers = self._get_subparsers() mm_items = MultiModalDataItems() @@ -353,16 +332,6 @@ def parse_mm_data(self, if k not in subparsers: raise ValueError(f"Unsupported modality: {k}") - modality_items = subparsers[k](v) - - if k in max_mm_counts: - max_count = max_mm_counts[k] - if len(modality_items) > max_count: - raise ValueError( - f"This model supports at most {max_count} {k} items " - f"per prompt, but {len(modality_items)} {k} items " - "were given or set as its limit_mm_per_prompt.") - - mm_items[k] = modality_items + mm_items[k] = subparsers[k](v) return mm_items diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 7712c3bcebe20..76475ddda81f4 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -624,6 +624,29 @@ def __call__( ) -> MultiModalInputsV2: return self.apply(prompt, mm_data, hf_processor_mm_kwargs) + @abstractmethod + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + """ + Return the maximum supported number of items for each modality. + + A value of `None` means unlimited number of items. + + Omitting a modality from the returned dictionary means that + it is not supported at all. + """ + raise NotImplementedError + + @abstractmethod + def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + """ + Get the maximum possible number of tokens per data item + for each modality. + + The dictionary returned by this method should have the same + keys as that returned by :meth:`get_supported_mm_limits`. + """ + raise NotImplementedError + def _get_data_parser(self) -> MultiModalDataParser: """ Construct a data parser to preprocess multi-modal data items @@ -653,7 +676,18 @@ def _to_mm_items( before passing them to :meth:`_get_hf_mm_data`. """ parser = self._get_data_parser() - return parser.parse_mm_data(mm_data) + mm_items = parser.parse_mm_data(mm_data) + + mm_limits = self.ctx.get_mm_config().limit_per_prompt + for modality, items in mm_items.items(): + limit = mm_limits.get(modality, 1) + if len(items) > limit: + raise ValueError( + f"You set {modality}={limit} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but passed {len(items)} " + f"{modality} items in the same prompt.") + + return mm_items @abstractmethod def _get_mm_fields_config( @@ -901,6 +935,17 @@ def _bind_prompt_replacements( return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls] + def _always_apply_prompt_replacements(self) -> bool: + """ + A flag which can be overridden so that + :meth:`_apply_prompt_replacements` is always called even if we + detect that HF has performed processing via :meth:`_find_placeholders`. + + This is useful in cases where :meth:`_find_placeholders` cannot be + reliably used to detect whether HF has performed processing or not. 
+ """ + return False + def _apply_prompt_replacements( self, token_ids: list[int], @@ -995,7 +1040,7 @@ def apply( all_placeholders = self._find_placeholders(prompt_repls, prompt_ids, mm_item_counts) - if all_placeholders: + if all_placeholders and not self._always_apply_prompt_replacements(): tokenizer = self._get_tokenizer() prompt_text = _decode(tokenizer, prompt_ids) else: @@ -1009,10 +1054,27 @@ def apply( mm_item_counts, ) - mm_placeholders = { - modality: [item.to_range() for item in items] - for modality, items in full_groupby_modality(all_placeholders) - } + mm_placeholders = dict[str, list[PlaceholderRange]]() + err_suffix = ("This suggests a problem with your implementation of " + "the merged multi-modal processor for this model, " + "particularly in the `_get_prompt_replacements` method.") + + for modality, placeholders in full_groupby_modality(all_placeholders): + if modality not in mm_items: + raise AssertionError( + f"Expected no placeholders for {modality=}, " + f"but found {placeholders=}. Input items: {mm_items}" + f"\n{err_suffix}") + + if len(placeholders) != len(mm_items[modality]): + raise AssertionError( + f"Expected length of {placeholders=} for {modality=} " + f"to equal that of input items: {mm_items[modality]}" + f"\n{err_suffix}") + + mm_placeholders[modality] = [ + item.to_range() for item in placeholders + ] return MultiModalInputsV2( type="multimodal", @@ -1063,15 +1125,38 @@ def _get_dummy_mm_inputs( """ raise NotImplementedError - def get_dummy_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_max_tokens: Mapping[str, int], - ) -> DummyData: + def _get_and_validate_dummy_mm_counts(self) -> Mapping[str, int]: + mm_limit_per_prompt = self.ctx.get_mm_config().limit_per_prompt + supported_mm_limits = self.get_supported_mm_limits() + + mm_limits = { + modality: mm_limit_per_prompt.get(modality, 1) + for modality in supported_mm_limits + } + + for modality, supported_limit in supported_mm_limits.items(): + limit = mm_limits[modality] + if supported_limit is not None and supported_limit < limit: + raise ValueError( + f"You set {modality}={limit} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but this model only supports " + f"at most {supported_limit} {modality} items.") + + return mm_limits + + def get_dummy_data(self, seq_len: int) -> DummyData: # Avoid circular import from vllm.sequence import SequenceData + mm_counts = self._get_and_validate_dummy_mm_counts() + mm_max_tokens_per_item = self.get_mm_max_tokens_per_item() + if mm_counts.keys() != mm_max_tokens_per_item.keys(): + raise AssertionError( + "The keys returned by `get_supported_mm_limits`" + f"({set(mm_counts.keys())}) should be the same as those " + "returned by `get_mm_max_tokens_per_item` " + f"({set(mm_max_tokens_per_item.keys())})") + processor_inputs = self._get_dummy_mm_inputs(mm_counts) mm_inputs = self.apply( prompt_text=processor_inputs.prompt_text, @@ -1087,7 +1172,7 @@ def get_dummy_data( for modality, placeholders in placeholders_by_modality.items() } expected_placeholders_by_modality = { - modality: mm_max_tokens[modality] + modality: mm_max_tokens_per_item[modality] * mm_counts[modality] for modality in placeholders_by_modality } if total_placeholders_by_modality != expected_placeholders_by_modality: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 3a5e11867ad9e..073d49d7d2009 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -15,6 +15,7 @@ from .image import ImagePlugin from .inputs import 
MultiModalDataDict, MultiModalKwargs, NestedTensors from .processing import BaseMultiModalProcessor, ProcessingCache +from .utils import cached_get_tokenizer from .video import VideoPlugin if TYPE_CHECKING: @@ -219,6 +220,10 @@ def get_max_tokens_per_item_by_modality( Note: This is currently directly used only in V1. """ + if self.has_processor(model_config): + tokenizer = cached_get_tokenizer(model_config.tokenizer) + processor = self.create_processor(model_config, tokenizer) + return processor.get_mm_max_tokens_per_item() return { key: plugin.get_max_multimodal_tokens(model_config)
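
For reference, a minimal sketch of what the merged multi-modal processor interface introduced by this patch asks of each model, following the Blip2/Aria implementations above. `MyModelMultiModalProcessor` and the fixed token count are hypothetical placeholders; only the hook names and signatures come from this patch, and the remaining hooks (`_get_mm_fields_config`, `_get_prompt_replacements`, `_get_dummy_mm_inputs`) are unchanged and omitted here.

from typing import Mapping, Optional

from vllm.multimodal.processing import BaseMultiModalProcessor


class MyModelMultiModalProcessor(BaseMultiModalProcessor):
    # Illustrative sketch only, not part of this patch.

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # At most one image per prompt; `None` would mean "unlimited",
        # as in the Aria and LLaVA processors above.
        return {"image": 1}

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        # Worst-case placeholder-token budget for a single image.
        # Real models derive this from the HF config or processor;
        # 576 is a made-up value for illustration.
        return {"image": 576}

    # _get_mm_fields_config, _get_prompt_replacements and
    # _get_dummy_mm_inputs are implemented exactly as before.

With these two hooks in place, profiling no longer needs per-model registry decorators: `processor.get_dummy_data(seq_len)` derives the per-prompt item counts from `--limit-mm-per-prompt` (validated against `get_supported_mm_limits`) and the per-item token budget from `get_mm_max_tokens_per_item`.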