diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index b32faa699ebf2..75d878217b657 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -586,9 +586,10 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): ) processor = processor_factory(ctx, cache=None) + profiler = processor.profiling_info mock_supported_mm_limits = MagicMock(return_value={"image": num_supported}) - processor.get_supported_mm_limits = mock_supported_mm_limits + profiler.get_supported_mm_limits = mock_supported_mm_limits if is_valid: exc_ctx = nullcontext() @@ -596,7 +597,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): exc_ctx = pytest.raises(ValueError, match="this model only supports") with exc_ctx: - processor._get_and_validate_dummy_mm_counts() + profiler.get_mm_limits() @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @@ -723,7 +724,7 @@ def _test_processing_cache_correctness( } mm_counts = {k: len(vs) for k, vs in mm_data.items()} - prompt = baseline_processor._get_dummy_processor_inputs( + prompt = baseline_processor.profiling_info.get_dummy_processor_inputs( model_config.max_model_len, mm_counts, ).prompt_text diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 8f5fd64a90c87..2e649f10c0765 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -24,8 +24,9 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, NestedTensors) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, + MultiModalDataItems, ProcessingMixin, PromptReplacement) +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, AriaVisionConfig) @@ -444,18 +445,58 @@ def build_mm_projector(config: PretrainedConfig): ) -class AriaMultiModalProcessor(BaseMultiModalProcessor): +class AriaProcessingMixin(ProcessingMixin): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} + def _get_hf_config(self): + return self.ctx.get_hf_config() + + def _get_vision_config(self) -> AriaVisionConfig: + return self._get_hf_config().vision_config def _get_num_image_tokens(self) -> int: - hf_config = self.ctx.get_hf_config() + hf_config = self._get_hf_config() return max(hf_config.projector_patch_to_query_dict.values()) + +class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"image": self._get_num_image_tokens()} + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + vision_config = self._get_vision_config() + + max_image_size = vision_config.image_size + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=max_image_size, + height=max_image_size, + num_images=num_images) + } + + hf_processor = self._get_hf_processor() + image_token: str = hf_processor.image_token # type: ignore + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) + + +class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor): + + def _get_profiling_info(self) -> 
BaseProfilingInfo: + return AriaProfilingInfo(self.ctx) + def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -472,7 +513,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self.ctx.get_hf_config() + hf_config = self._get_hf_config() image_token_id = hf_config.image_token_index num_image_tokens = self._get_num_image_tokens() @@ -485,32 +526,6 @@ def _get_prompt_replacements( ) ] - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - hf_config = self.ctx.get_hf_config() - vision_config: AriaVisionConfig = hf_config.vision_config - - max_image_size = vision_config.image_size - num_images = mm_counts.get("image", 0) - - mm_data = { - "image": - self._get_dummy_images(width=max_image_size, - height=max_image_size, - num_images=num_images) - } - - hf_processor = self._get_hf_processor() - image_token: str = hf_processor.image_token # type: ignore - - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) - @MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor) class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index b3ecb2f22dc19..fd45783f167b4 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn -from transformers import (BatchFeature, Blip2Config, Blip2Processor, - Blip2QFormerConfig, apply_chunking_to_forward) +from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, + apply_chunking_to_forward) from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, VllmConfig @@ -18,8 +18,9 @@ MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, + MultiModalDataItems, ProcessingMixin, PromptReplacement) +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors from .blip import BlipVisionModel @@ -396,20 +397,52 @@ def forward( return sequence_output -class Blip2MultiModalProcessor(BaseMultiModalProcessor): +class Blip2ProcessingMixin(ProcessingMixin): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": 1} + def _get_hf_config(self): + return self.ctx.get_hf_config(Blip2Config) def _get_num_image_tokens(self) -> int: - hf_config = self.ctx.get_hf_config(Blip2Config) + hf_config = self._get_hf_config() return hf_config.num_query_tokens + +class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"image": self._get_num_image_tokens()} - def _get_hf_processor(self) -> Blip2Processor: - return self.ctx.get_hf_processor(Blip2Processor) + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + hf_config = self._get_hf_config() + vision_config = hf_config.vision_config + + max_image_size = vision_config.image_size + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=max_image_size, + height=max_image_size, + num_images=num_images) + } + + return ProcessorInputs( + 
prompt_text="", + mm_data=mm_data, + ) + + +class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor): + + def _get_profiling_info(self) -> BaseProfilingInfo: + return Blip2ProfilingInfo(self.ctx) def _get_mm_fields_config( self, @@ -427,13 +460,13 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - max_image_tokens = self._get_num_image_tokens() + num_image_tokens = self._get_num_image_tokens() return [ PromptReplacement( modality="image", target="", - replacement="" * max_image_tokens + "", + replacement="" * num_image_tokens + "", ) ] @@ -457,29 +490,6 @@ def apply( return result - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - hf_config = self.ctx.get_hf_config(Blip2Config) - vision_config = hf_config.vision_config - - max_image_size = vision_config.image_size - num_images = mm_counts.get("image", 0) - - mm_data = { - "image": - self._get_dummy_images(width=max_image_size, - height=max_image_size, - num_images=num_images) - } - - return ProcessorInputs( - prompt_text="", - mm_data=mm_data, - ) - @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor) class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 1ad44678a591d..73ed73b61ebf9 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -31,8 +31,9 @@ MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, + MultiModalDataItems, ProcessingMixin, PromptReplacement) +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once @@ -48,20 +49,55 @@ class ChameleonImagePixelInputs(TypedDict): """Shape: `(batch_size * num_images, num_channels, height, width)`""" -class ChameleonMultiModalProcessor(BaseMultiModalProcessor): +class ChameleonProcessingMixin(ProcessingMixin): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": 1} + def _get_hf_config(self): + return self.ctx.get_hf_config(ChameleonConfig) + + def _get_hf_processor(self): + return self.ctx.get_hf_processor(ChameleonProcessor) def _get_num_image_tokens(self) -> int: processor = self._get_hf_processor() return processor.image_seq_length + +class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"image": self._get_num_image_tokens()} - def _get_hf_processor(self) -> ChameleonProcessor: - return self.ctx.get_hf_processor(ChameleonProcessor) + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + config = self._get_hf_config() + + width = height = config.vq_config.resolution + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=width, + height=height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="" * num_images, + mm_data=mm_data, + ) + + +class ChameleonMultiModalProcessor(ChameleonProcessingMixin, + BaseMultiModalProcessor): + + def _get_profiling_info(self) 
-> BaseProfilingInfo: + return ChameleonProfilingInfo(self.ctx) def _get_mm_fields_config( self, @@ -76,7 +112,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - processor = self._get_hf_processor() + processor = self._get_hf_processor(**hf_processor_mm_kwargs) return [ PromptReplacement( @@ -90,28 +126,6 @@ def _get_prompt_replacements( ) ] - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - config = self.ctx.get_hf_config(ChameleonConfig) - - width = height = config.vq_config.resolution - num_images = mm_counts.get("image", 0) - - mm_data = { - "image": - self._get_dummy_images(width=width, - height=height, - num_images=num_images) - } - - return ProcessorInputs( - prompt_text="" * num_images, - mm_data=mm_data, - ) - def apply( self, prompt_text: str, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 7cd58fbc7cf21..c937fcb0978b9 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -35,8 +35,9 @@ NestedTensors, PlaceholderRange) from vllm.multimodal.parse import ImageProcessorItems, ImageSize from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, + MultiModalDataItems, ProcessingMixin, PromptReplacement) +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -63,18 +64,16 @@ class FuyuImagePatchInputs(TypedDict): """ -class FuyuMultiModalProcessor(BaseMultiModalProcessor): +class FuyuProcessingMixin(ProcessingMixin): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": 1} + def _get_hf_config(self): + return self.ctx.get_hf_config(FuyuConfig) - def _get_image_target_size(self) -> ImageSize: - processor = self._get_hf_processor() - image_processor: FuyuImageProcessor = processor.image_processor + def _get_hf_processor(self): + return self.ctx.get_hf_processor(FuyuProcessor) - target_size = image_processor.size - return ImageSize(width=target_size["width"], - height=target_size["height"]) + def _get_image_processor(self) -> FuyuImageProcessor: + return self._get_hf_processor().image_processor def _get_image_feature_grid_size( self, @@ -82,7 +81,9 @@ def _get_image_feature_grid_size( image_width: int, image_height: int, ) -> tuple[int, int]: - target_width, target_height = self._get_image_target_size() + image_processor = self._get_image_processor() + target_width = image_processor.size["width"] + target_height = image_processor.size["height"] if not (image_width <= target_width and image_height <= target_height): height_scale_factor = target_height / image_height @@ -96,8 +97,14 @@ def _get_image_feature_grid_size( nrows = math.ceil(image_height / 30) return ncols, nrows + +class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - target_width, target_height = self._get_image_target_size() + target_width, target_height = self._get_image_size_with_most_features() max_ncols, max_nrows = self._get_image_feature_grid_size( image_width=target_width, @@ -107,8 +114,36 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"image": max_image_tokens} - def 
_get_hf_processor(self) -> FuyuProcessor: - return self.ctx.get_hf_processor(FuyuProcessor) + def _get_image_size_with_most_features(self) -> ImageSize: + image_processor = self._get_image_processor() + return ImageSize(width=image_processor.size["width"], + height=image_processor.size["height"]) + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + target_width, target_height = self._get_image_size_with_most_features() + num_images = mm_counts.get("image", 0) + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text="", + mm_data=mm_data, + ) + + +class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor): + + def _get_profiling_info(self) -> BaseProfilingInfo: + return FuyuProfilingInfo(self.ctx) def _call_hf_processor( self, @@ -161,7 +196,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self.ctx.get_hf_config(FuyuConfig) + hf_config = self._get_hf_config() bos_token_id = hf_config.bos_token_id tokenizer = self._get_tokenizer() @@ -208,26 +243,6 @@ def apply( return result - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - target_width, target_height = self._get_image_target_size() - num_images = mm_counts.get("image", 0) - - mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } - - return ProcessorInputs( - prompt_text="", - mm_data=mm_data, - ) - @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor) class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index d522378e0bebb..4299af8cd03a2 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,4 +1,4 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from functools import cached_property from typing import (Final, Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) @@ -13,6 +13,7 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig +from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -25,9 +26,10 @@ NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize) -from vllm.multimodal.processing import (InputProcessingContext, +from vllm.multimodal.processing import (BaseMultiModalProcessor, MultiModalDataItems, ProcessingCache, - ProcessorInputs, PromptReplacement) + ProcessingMixin, PromptReplacement) +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel @@ -37,7 +39,7 @@ from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import BaseVisionLanguageMultiModalProcessor +from .vision import get_vision_encoder_info class LlavaImagePixelInputs(TypedDict): @@ -94,30 +96,42 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: class LlavaLikeConfig(Protocol): vision_config: 
Final[PretrainedConfig] + image_token_index: Final[int] vision_feature_select_strategy: Final[str] - vision_feature_layer: Final[Union[int, List[int]]] + vision_feature_layer: Final[Union[int, list[int]]] -class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor): +class LlavaLikeProcessor(Protocol): + image_token: Final[str] + + +class BaseLlavaProcessingMixin(ProcessingMixin, ABC): - @abstractmethod def _get_hf_config(self) -> LlavaLikeConfig: - raise NotImplementedError + return self.ctx.get_hf_config(LlavaConfig) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} + def _get_vision_encoder_info(self): + return get_vision_encoder_info(self._get_hf_config()) - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": self._get_max_image_tokens()} + @abstractmethod + def _get_hf_processor(self) -> LlavaLikeProcessor: + raise NotImplementedError - def _get_mm_fields_config( + def _get_num_image_tokens( self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - image_embeds=MultiModalFieldConfig.batched("image"), + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self._get_hf_config() + vision_encoder_info = self._get_vision_encoder_info() + + return self._apply_feature_select_strategy( + hf_config.vision_feature_select_strategy, + vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ), ) def _apply_feature_select_strategy( @@ -133,31 +147,38 @@ def _apply_feature_select_strategy( msg = f"Unexpected feature select strategy: {strategy!r}" raise NotImplementedError(msg) - def _get_max_image_tokens(self) -> int: - hf_config = self._get_hf_config() - return self._apply_feature_select_strategy( - hf_config.vision_feature_select_strategy, - self._vision_encoder_info.get_max_image_tokens(), - ) +class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo): - def _get_dummy_image_size(self) -> ImageSize: - image_size = self._vision_encoder_info.get_image_size() - return ImageSize(image_size, image_size) + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} - @abstractmethod - def _get_image_token(self) -> str: - raise NotImplementedError + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return {"image": self._get_max_image_tokens()} + + def _get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self._get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) - def _get_dummy_processor_inputs( + def _get_max_image_tokens(self) -> int: + target_width, target_height = self._get_image_size_with_most_features() + + return self._get_num_image_tokens( + image_width=target_width, + image_height=target_height, + ) + + def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: num_images = mm_counts.get("image", 0) - image_token = self._get_image_token() - target_width, target_height = self._get_dummy_image_size() + processor = self._get_hf_processor() + image_token = processor.image_token + target_width, target_height = self._get_image_size_with_most_features() mm_data = { "image": @@ -172,32 +193,32 @@ def _get_dummy_processor_inputs( ) -class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor): - 
- def _get_hf_config(self) -> LlavaConfig: - return self.ctx.get_hf_config(LlavaConfig) +class LlavaProcessingMixin(BaseLlavaProcessingMixin): - def _get_hf_processor(self) -> LlavaProcessor: + def _get_hf_processor(self): return self.ctx.get_hf_processor(LlavaProcessor) - def _get_image_token(self) -> str: - return self._get_hf_processor().image_token - def _get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - ) -> int: - hf_config = self._get_hf_config() +class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo): + pass - return self._apply_feature_select_strategy( - hf_config.vision_feature_select_strategy, - self._vision_encoder_info.get_num_image_tokens( - image_width=image_width, - image_height=image_height, - ), - ) + +class BaseLlavaMultiModalProcessor(LlavaProcessingMixin, + BaseMultiModalProcessor): + + # Copied from BaseMultiModalProcessor + @abstractmethod + def _get_profiling_info(self) -> BaseProfilingInfo: + raise NotImplementedError + + # Copied from BaseMultiModalProcessor + @abstractmethod + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + raise NotImplementedError def _get_prompt_replacements( self, @@ -232,16 +253,37 @@ def get_replacement(item_idx: int): ] -class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor): +class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor): + + def _get_profiling_info(self) -> BaseProfilingInfo: + return LlavaProfilingInfo(self.ctx) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) - def _get_hf_config(self) -> LlavaConfig: - return self.ctx.get_hf_config(LlavaConfig) - def _get_hf_processor(self) -> PixtralProcessor: +class PixtralHFProcessingMixin(BaseLlavaProcessingMixin): + + def _get_hf_processor(self): return self.ctx.get_hf_processor(PixtralProcessor) - def _get_image_token(self) -> str: - return self._get_hf_processor().image_token + +class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo): + pass + + +class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin, + BaseMultiModalProcessor): + + def _get_profiling_info(self) -> BaseProfilingInfo: + return PixtralHFProfilingInfo(self.ctx) def _call_hf_processor( self, @@ -270,6 +312,16 @@ def _call_hf_processor( return processed_outputs + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + def _get_prompt_replacements( self, mm_items: MultiModalDataItems, @@ -316,7 +368,7 @@ def _build_llava_or_pixtral_hf_processor( *, cache: Optional[ProcessingCache] = None, enable_sanity_checks: bool = True, -) -> BaseLlavaMultiModalProcessor: +) -> BaseMultiModalProcessor: hf_config = ctx.get_hf_config(LlavaConfig) if isinstance(hf_config.vision_config, PixtralVisionConfig): @@ -663,16 +715,13 @@ def load_weights(self, weights: Iterable[Tuple[str, class MantisMultiModalProcessor(LlavaMultiModalProcessor): - def _get_hf_processor(self): - return self.ctx.get_hf_processor(LlavaProcessor) - def apply( self, prompt_text: str, mm_data: MultiModalDataDict, 
hf_processor_mm_kwargs: Mapping[str, object], ) -> MultiModalInputsV2: - hf_config = self.ctx.get_hf_config(LlavaConfig) + hf_config = self._get_hf_config() image_token_id = hf_config.image_token_index # Assume that it doesn't depend on the image size diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index f79021596f915..c76ec164a3087 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,6 +1,6 @@ from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Final, Iterable, List, Literal, Mapping, Optional, + Protocol, Set, Tuple, TypedDict, Union) import numpy as np import torch @@ -17,12 +17,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors from vllm.multimodal.parse import ImageSize +from vllm.multimodal.profiling import BaseProfilingInfo from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector, - init_vision_tower_for_llava) +from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin, + BaseLlavaProfilingInfo, LlavaLikeConfig, + LlavaMultiModalProjector, init_vision_tower_for_llava) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn, init_vllm_registered_model, maybe_prefix) @@ -60,35 +62,17 @@ class LlavaNextImageEmbeddingInputs(TypedDict): LlavaNextImageEmbeddingInputs] -class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor): +class LlavaNextLikeConfig(LlavaLikeConfig, Protocol): + image_grid_pinpoints: Final[list[list[int]]] - def _get_hf_config(self) -> LlavaNextConfig: - return self.ctx.get_hf_config(LlavaNextConfig) - def _get_hf_processor(self) -> LlavaNextProcessor: - return self.ctx.get_hf_processor(LlavaNextProcessor) +class LlavaNextProcessingMixin(BaseLlavaProcessingMixin): - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - image_sizes=MultiModalFieldConfig.batched("image"), - image_embeds=MultiModalFieldConfig.batched("image"), - ) - - def _get_image_token(self) -> str: - return self._get_hf_processor().image_token - - def _get_max_image_tokens(self) -> int: - largest_feature_size, _ = self._get_pinpoint_with_most_features() - return largest_feature_size + def _get_hf_config(self) -> LlavaNextLikeConfig: + return self.ctx.get_hf_config(LlavaNextConfig) - def _get_dummy_image_size(self) -> ImageSize: - _, pinpoint = self._get_pinpoint_with_most_features() - return pinpoint + def _get_hf_processor(self): + return self.ctx.get_hf_processor(LlavaNextProcessor) # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106 def _get_num_image_tokens( @@ -98,7 +82,7 @@ def _get_num_image_tokens( image_height: int, ) -> int: hf_config = self._get_hf_config() - vision_encoder_info = self._vision_encoder_info + vision_encoder_info = self._get_vision_encoder_info() base_feature_size = self._apply_feature_select_strategy( hf_config.vision_feature_select_strategy, @@ -140,7 +124,7 @@ def _get_num_unpadded_features( current_height = npatches * 
num_patch_height current_width = npatches * num_patch_width - # NOTE: HF resizes based on float32 + # NOTE: Use float32 to remain consistent with HF output original_aspect_ratio = np.array(original_width / original_height, dtype=np.float32) current_aspect_ratio = np.array(current_width / current_height, @@ -164,11 +148,10 @@ def _get_num_unpadded_features( return (unpadded_features, newline_features) - def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]: - """ - Get the grid pinpoint with the most features and - the corresponding feature size. - """ + +class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo): + + def _get_image_size_with_most_features(self) -> ImageSize: hf_config = self._get_hf_config() largest_feature_size, largest_feature_pinpoint = 0, None @@ -183,7 +166,25 @@ def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]: if largest_feature_size == 0 or largest_feature_pinpoint is None: raise ValueError("Cannot have a largest feature size of 0!") - return largest_feature_size, largest_feature_pinpoint + return largest_feature_pinpoint + + +class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin, + BaseLlavaMultiModalProcessor): + + def _get_profiling_info(self) -> BaseProfilingInfo: + return LlavaNextProfilingInfo(self.ctx) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) @MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index ee6b89f0d4498..6e82cee1c95a4 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -15,11 +15,14 @@ from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors -from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, - VideoEmbeddingItems, VideoProcessorItems) -from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs, +from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, + NestedTensors) +from vllm.multimodal.parse import (ImageSize, VideoEmbeddingItems, + VideoProcessorItems) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + MultiModalDataItems, ProcessingMixin, PromptReplacement) +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -28,7 +31,7 @@ from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import BaseVisionLanguageMultiModalProcessor +from .vision import get_vision_encoder_info class LlavaNextVideoPixelInputs(TypedDict): @@ -44,29 +47,16 @@ class LlavaNextVideoPixelInputs(TypedDict): """ -class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor): +class LlavaNextVideoProcessingMixin(ProcessingMixin): - def _get_hf_config(self) -> LlavaNextVideoConfig: + def _get_hf_config(self): return self.ctx.get_hf_config(LlavaNextVideoConfig) - 
def _get_hf_processor(self) -> LlavaNextVideoProcessor: - return self.ctx.get_hf_processor(LlavaNextVideoProcessor) - - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"video": 1} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - num_frames = self._get_dummy_num_frames(seq_len) - max_video_tokens = self._get_max_video_tokens(num_frames) - - return {"video": max_video_tokens} + def _get_vision_encoder_info(self): + return get_vision_encoder_info(self._get_hf_config()) - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict(pixel_values_videos=MultiModalFieldConfig.batched("video")) + def _get_hf_processor(self): + return self.ctx.get_hf_processor(LlavaNextVideoProcessor) def _get_num_frame_tokens( self, @@ -77,7 +67,8 @@ def _get_num_frame_tokens( hf_config = self._get_hf_config() spatial_pool_stride = hf_config.spatial_pool_stride - patch_grid_length = self._vision_encoder_info.get_patch_grid_length() + vision_encoder_info = self._get_vision_encoder_info() + patch_grid_length = vision_encoder_info.get_patch_grid_length() pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) return pooled_grid_length * pooled_grid_length @@ -96,18 +87,43 @@ def _get_num_video_tokens( return num_frame_tokens * num_frames - def _get_max_video_tokens(self, num_frames: int) -> int: - return self._get_num_video_tokens(image_width=999999, - image_height=999999, - num_frames=num_frames) + +class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin, + BaseProfilingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"video": 1} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + target_width, target_height = self._get_image_size_with_most_features() + + max_video_tokens = self._get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self._get_dummy_num_frames(seq_len), + ) + + return {"video": max_video_tokens} + + def _get_image_size_with_most_features(self) -> ImageSize: + vision_encoder_info = self._get_vision_encoder_info() + width = height = vision_encoder_info.get_image_size() + return ImageSize(width=width, height=height) def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self._get_image_size_with_most_features() + num_frames = 0 while True: next_num_frames = num_frames + 1 + next_max_tokens = self._get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + ) - if self._get_max_video_tokens(next_num_frames) > max_tokens: + if next_max_tokens > max_tokens: break num_frames = next_num_frames @@ -122,12 +138,45 @@ def _get_dummy_num_frames(self, seq_len: int) -> int: return max(max_total_frames // max(max_videos, 1), 1) - def _get_dummy_image_size(self) -> ImageSize: - image_size = self._vision_encoder_info.get_image_size() - return ImageSize(image_size, image_size) + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_videos = mm_counts.get("video", 0) + + processor = self._get_hf_processor() + video_token = processor.video_token + target_width, target_height = self._get_image_size_with_most_features() + + mm_data = { + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=self._get_dummy_num_frames(seq_len), + 
num_videos=num_videos, + ) + } + + return ProcessorInputs( + prompt_text=video_token * num_videos, + mm_data=mm_data, + ) + - def _get_video_token(self) -> str: - return self._get_hf_processor().video_token +class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin, + BaseMultiModalProcessor): + + def _get_profiling_info(self) -> BaseProfilingInfo: + return LlavaNextVideoProfilingInfo(self.ctx) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values_videos=MultiModalFieldConfig.batched("video")) def _get_prompt_replacements( self, @@ -162,36 +211,11 @@ def get_replacement(item_idx: int): ), ] - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - num_videos = mm_counts.get("video", 0) - - video_token = self._get_video_token() - target_width, target_height = self._get_dummy_image_size() - - mm_data = { - "video": - self._get_dummy_videos( - width=target_width, - height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), - num_videos=num_videos, - ) - } - - return ProcessorInputs( - prompt_text=video_token * num_videos, - mm_data=mm_data, - ) - # adopted from transformers modeling_llava_next_video.py class LlavaNextVideoPooler(nn.Module): - def __init__(self, config): + def __init__(self, config: LlavaNextVideoConfig): super().__init__() mode = config.spatial_pool_mode @@ -209,7 +233,7 @@ def __init__(self, config): raise ValueError( f"Unknown pooling mode: {mode}. Expected [`average`, `max`]") - def forward(self, image_features): + def forward(self, image_features: torch.Tensor): ori_width = int( math.sqrt(image_features.shape[1] * self.image_size // self.image_size)) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 5a3cdadc47cac..6dccc1e0d3b8d 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -1,7 +1,7 @@ import math from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Final, Iterable, List, Literal, Mapping, Optional, + Protocol, Set, Tuple, TypedDict, Union) import numpy as np import torch @@ -21,15 +21,16 @@ from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems, VideoProcessorItems) -from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs, - PromptReplacement) +from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP -from .llava import init_vision_tower_for_llava -from .llava_next import LlavaNextMultiModalProcessor +from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava +from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor, + LlavaNextProcessingMixin) from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -82,39 +83,17 @@ class LlavaOnevisionImageEmbeddingInputs(TypedDict): LlavaOnevisionVideoPixelInputs] -class 
LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor): +class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol): + video_token_index: Final[int] - def _get_hf_config(self) -> LlavaOnevisionConfig: - return self.ctx.get_hf_config(LlavaOnevisionConfig) - - def _get_hf_processor(self) -> LlavaOnevisionProcessor: - return self.ctx.get_hf_processor(LlavaOnevisionProcessor) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None, "video": None} - - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - max_image_tokens = self._get_max_image_tokens() +class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin): - num_frames = self._get_dummy_num_frames(seq_len) - max_video_tokens = self._get_max_video_tokens(num_frames) - - return { - "image": max_image_tokens, - "video": max_video_tokens, - } + def _get_hf_config(self) -> LlavaOnevisionLikeConfig: + return self.ctx.get_hf_config(LlavaOnevisionConfig) - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - image_sizes=MultiModalFieldConfig.batched("image"), - image_embeds=MultiModalFieldConfig.batched("image"), - pixel_values_videos=MultiModalFieldConfig.batched("video"), - ) + def _get_hf_processor(self): + return self.ctx.get_hf_processor(LlavaOnevisionProcessor) def _get_num_unpadded_features( self, @@ -128,7 +107,7 @@ def _get_num_unpadded_features( current_height = npatches * num_patch_height current_width = npatches * num_patch_width - # NOTE: HF resizes based on float32 + # NOTE: Use float32 to remain consistent with HF output original_aspect_ratio = np.array(original_width / original_height, dtype=np.float32) current_aspect_ratio = np.array(current_width / current_height, @@ -167,7 +146,8 @@ def _get_num_frame_tokens( hf_config = self._get_hf_config() spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2) - patch_grid_length = self._vision_encoder_info.get_patch_grid_length() + vision_encoder_info = self._get_vision_encoder_info() + patch_grid_length = vision_encoder_info.get_patch_grid_length() pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) return pooled_grid_length * pooled_grid_length @@ -186,18 +166,33 @@ def _get_num_video_tokens( return num_frame_tokens * num_frames + 1 # Newline token - def _get_max_video_tokens(self, num_frames: int) -> int: - return self._get_num_video_tokens(image_width=999999, - image_height=999999, - num_frames=num_frames) + +class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin, + BaseLlavaProfilingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return { + "image": self._get_max_image_tokens(), + "video": self._get_max_video_tokens(seq_len), + } def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self._get_image_size_with_most_features() + num_frames = 0 while True: next_num_frames = num_frames + 1 + next_max_tokens = self._get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + ) - if self._get_max_video_tokens(next_num_frames) > max_tokens: + if next_max_tokens > max_tokens: break num_frames = next_num_frames @@ -215,8 +210,65 @@ def _get_dummy_num_frames(self, seq_len: int) -> int: 
return max(max_total_frames // max(max_videos, 1), 1) - def _get_video_token(self) -> str: - return self._get_hf_processor().video_token + def _get_max_video_tokens(self, seq_len: int) -> int: + target_width, target_height = self._get_image_size_with_most_features() + + return self._get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self._get_dummy_num_frames(seq_len), + ) + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + processor = self._get_hf_processor() + image_token = processor.image_token + video_token = processor.video_token + target_width, target_height = self._get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=self._get_dummy_num_frames(seq_len), + num_videos=num_videos, + ) + } + + return ProcessorInputs( + prompt_text=image_token * num_images + video_token * num_videos, + mm_data=mm_data, + ) + + +class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin, + LlavaNextMultiModalProcessor): + + def _get_profiling_info(self) -> BaseProfilingInfo: + return LlavaOnevisionProfilingInfo(self.ctx) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.batched("video"), + ) def _call_hf_processor( self, @@ -235,7 +287,8 @@ def _call_hf_processor( mm_kwargs=mm_kwargs, ) - video_token = self._get_video_token() + processor = self._get_hf_processor() + video_token = processor.video_token # LLaVA-OneVision processor doesn't support multiple videos # with different sizes when converting back to tensors @@ -303,37 +356,6 @@ def get_video_replacement(item_idx: int): ), ] - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - num_images = mm_counts.get("image", 0) - num_videos = mm_counts.get("video", 0) - - image_token = self._get_image_token() - video_token = self._get_video_token() - target_width, target_height = self._get_dummy_image_size() - - mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "video": - self._get_dummy_videos( - width=target_width, - height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), - num_videos=num_videos, - ) - } - - return ProcessorInputs( - prompt_text=image_token * num_images + video_token * num_videos, - mm_data=mm_data, - ) - class LlavaOnevisionMultiModalProjector(nn.Module): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 7aa9d58d1d348..c8418c14e5fdf 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -14,7 +14,7 @@ # limitations under the License. 
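[Editor's note] The model conversions above (Aria, BLIP-2, Chameleon, Fuyu, and the LLaVA family) all apply the same split: shared HF-config/processor helpers move into a `*ProcessingMixin`, profiling-only logic (`get_supported_mm_limits`, `get_mm_max_tokens_per_item`, `get_dummy_processor_inputs`) moves into a `*ProfilingInfo`, and each `*MultiModalProcessor` wires the two together via `_get_profiling_info()`. Below is a minimal standalone sketch of that shape; `FakeContext` and the `MyModel*` classes are invented stand-ins, not vLLM's real `ProcessingMixin`/`BaseProfilingInfo`/`BaseMultiModalProcessor` interfaces.

```python
from abc import ABC, abstractmethod
from typing import Mapping, Optional


class FakeContext:
    """Stand-in for vLLM's InputProcessingContext (illustration only)."""

    def __init__(self, hf_config: dict) -> None:
        self.hf_config = hf_config


class ProcessingMixin:
    """Helpers shared by a processor and its profiling info."""

    ctx: FakeContext

    def _get_hf_config(self) -> dict:
        return self.ctx.hf_config


class BaseProfilingInfo(ProcessingMixin, ABC):
    """Owns the dummy-input / max-token logic used for memory profiling."""

    def __init__(self, ctx: FakeContext) -> None:
        self.ctx = ctx

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        raise NotImplementedError


class MyModelProcessingMixin(ProcessingMixin):
    """Model-specific helpers reused by both classes below."""

    def _get_num_image_tokens(self) -> int:
        return self._get_hf_config()["num_query_tokens"]


class MyModelProfilingInfo(MyModelProcessingMixin, BaseProfilingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}


class MyModelMultiModalProcessor(MyModelProcessingMixin):

    def __init__(self, ctx: FakeContext) -> None:
        self.ctx = ctx

    def _get_profiling_info(self) -> BaseProfilingInfo:
        # Mirrors the per-model _get_profiling_info() overrides in this patch.
        return MyModelProfilingInfo(self.ctx)

    @property
    def profiling_info(self) -> BaseProfilingInfo:
        # The updated test accesses processor.profiling_info directly.
        return self._get_profiling_info()


if __name__ == "__main__":
    processor = MyModelMultiModalProcessor(FakeContext({"num_query_tokens": 32}))
    print(processor.profiling_info.get_mm_max_tokens_per_item(seq_len=2048))
    # -> {'image': 32}
```

This mirrors the updated unit test at the top of the patch, which patches `profiler.get_supported_mm_limits` on `processor.profiling_info` instead of patching methods on the processor itself.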
from collections.abc import Iterable, Mapping, Sequence from functools import cached_property -from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union +from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -28,22 +28,23 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputsV2, MultiModalKwargs, NestedTensors, PlaceholderRange) -from vllm.multimodal.parse import ImageEmbeddingItems, ImageProcessorItems +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + ImageSize) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, + MultiModalDataItems, ProcessingMixin, PromptReplacement, _BoundPromptReplacement, _PlaceholderInfo) +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from .clip import dummy_image_for_clip +from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -54,10 +55,6 @@ # Cannot find the following 2 numbers from hf config. _IMAGE_TOKEN_ID = 32044 -# Result in the max possible feature size (h:w = 16:1) -MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000 -MAX_IMAGE_FEATURE_SIZE_WIDTH = 50 - CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, hidden_act="quick_gelu", hidden_size=1024, @@ -305,10 +302,17 @@ def add_image_newline(self, image_features_hd): return image_features_hd_newline -class Phi3VMultiModalProcessor(BaseMultiModalProcessor): +class Phi3VProcessingMixin(ProcessingMixin): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None} + def _get_hf_processor( + self, + *, + num_crops: Optional[int] = None, + ) -> ProcessorMixin: + if num_crops is not None: + return self.ctx.get_hf_processor(num_crops=num_crops) + + return self.ctx.get_hf_processor() def _get_num_image_tokens( self, @@ -323,23 +327,55 @@ def _get_num_image_tokens( height=image_height, ) + +class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None} + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + target_width, target_height = self._get_image_size_with_most_features() + max_image_tokens = self._get_num_image_tokens( - image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, - image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + image_width=target_width, + image_height=target_height, ) return {"image": max_image_tokens} - def _get_hf_processor( + def _get_image_size_with_most_features(self) -> ImageSize: + # Result in the max possible feature size (h:w = 16:1) + return ImageSize(height=8000, width=50) + + def get_dummy_processor_inputs( self, - *, - num_crops: Optional[int] = None, - ) -> ProcessorMixin: - if num_crops is not None: - return self.ctx.get_hf_processor(num_crops=num_crops) + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) 
- return self.ctx.get_hf_processor() + target_width, target_height = self._get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + hf_processor = self._get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + + return ProcessorInputs( + prompt_text="".join(image_tokens[:num_images]), + mm_data=mm_data, + ) + + +class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor): + + def _get_profiling_info(self) -> BaseProfilingInfo: + return Phi3VProfilingInfo(self.ctx) def _call_hf_processor( self, @@ -377,10 +413,10 @@ def _get_mm_fields_config( def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self._get_hf_processor() + hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) image_tokens: list[str] = hf_processor.img_tokens # type: ignore tokenizer = self._get_tokenizer() @@ -442,28 +478,6 @@ def _apply_prompt_replacements( return token_ids, text, placeholders - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - num_images = mm_counts.get("image", 0) - - data = dummy_image_for_clip( - CLIP_VIT_LARGE_PATCH14_336_CONFIG, - num_images, - image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, - image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, - ) - - hf_processor = self._get_hf_processor() - image_tokens: list[str] = hf_processor.img_tokens # type: ignore - - return ProcessorInputs( - prompt_text="".join(image_tokens[:num_images]), - mm_data=data, - ) - def apply( self, prompt_text: str, diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index bc3bb1f79b407..a7bb3425ed17c 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -20,8 +20,8 @@ # limitations under the License. 
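[Editor's note] The video-capable conversions above (LLaVA-NeXT-Video, LLaVA-OneVision) replace the old `_get_max_video_tokens(num_frames)` helpers with a frame-budget search: grow the frame count until the per-video token cost would exceed the budget, then split the result across the requested number of dummy videos. Below is a simplified sketch of that search; the fixed per-frame token cost is a made-up stand-in for the model-specific `_get_num_video_tokens()`.

```python
def num_video_tokens(num_frames: int, tokens_per_frame: int = 196) -> int:
    """Stand-in for the model-specific _get_num_video_tokens(): assume a
    fixed per-frame cost instead of deriving it from the vision encoder."""
    return tokens_per_frame * num_frames


def max_video_frames(max_tokens: int) -> int:
    """Mirrors _get_max_video_frames(): grow the frame count until the next
    frame would push the per-video token cost past the budget."""
    num_frames = 0
    while True:
        next_num_frames = num_frames + 1
        if num_video_tokens(next_num_frames) > max_tokens:
            break
        num_frames = next_num_frames
    return num_frames


def dummy_num_frames(seq_len: int, max_videos: int) -> int:
    """Mirrors _get_dummy_num_frames(): split the frame budget across all
    requested dummy videos, keeping at least one frame per video."""
    max_total_frames = max_video_frames(seq_len)
    return max(max_total_frames // max(max_videos, 1), 1)


if __name__ == "__main__":
    print(max_video_frames(4096))                # 20 frames fit 4096 tokens
    print(dummy_num_frames(4096, max_videos=2))  # 10 frames per dummy video
```

Qwen2-VL, later in this patch, uses the same loop and additionally rounds odd frame counts up to an even number as a temporary workaround for the linked transformers issue.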
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" from functools import cached_property -from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, - Union) +from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -40,8 +40,9 @@ NestedTensors) from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, + MultiModalDataItems, ProcessingMixin, PromptReplacement) +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal, SupportsPP @@ -79,28 +80,70 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor): return feat_lengths, output_lengths -class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): +class Qwen2AudioProcessingMixin(ProcessingMixin): + + def _get_hf_config(self): + return self.ctx.get_hf_config(Qwen2AudioConfig) + + def _get_hf_processor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> Qwen2AudioProcessor: + return self.ctx.get_hf_processor(Qwen2AudioProcessor) + + def _get_feature_extractor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> WhisperFeatureExtractor: + hf_processor = self._get_hf_processor(sampling_rate=sampling_rate) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + +class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) + hf_config = self._get_hf_config() max_source_positions = hf_config.audio_config.max_source_positions max_output_lengths = (max_source_positions - 2) // 2 + 1 return {"audio": max_output_lengths} - def _get_hf_processor( + def get_dummy_processor_inputs( self, - *, - # Ignored in initialization - sampling_rate: Optional[int] = None, - ) -> Qwen2AudioProcessor: - return self.ctx.get_hf_processor(Qwen2AudioProcessor) + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + feature_extractor = self._get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) + + mm_data = { + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + return ProcessorInputs( + prompt_text="<|AUDIO|>" * num_audios, + mm_data=mm_data, + ) + - def _get_feature_extractor(self) -> WhisperFeatureExtractor: - return self._get_hf_processor().feature_extractor # type: ignore +class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin, + BaseMultiModalProcessor): + + def _get_profiling_info(self) -> BaseProfilingInfo: + return Qwen2AudioProfilingInfo(self.ctx) def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self._get_feature_extractor() @@ -110,7 +153,7 @@ def _call_hf_processor( self, prompt: str, mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], + mm_kwargs: Mapping[str, Any], ) -> BatchFeature: mm_data = dict(mm_data) audios = mm_data.pop("audios", []) @@ -118,7 +161,7 @@ def _call_hf_processor( 
if audios: mm_data["audios"] = audios - feature_extractor = self._get_feature_extractor() + feature_extractor = self._get_feature_extractor(**mm_kwargs) mm_kwargs = dict( **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, @@ -151,7 +194,7 @@ def _get_prompt_replacements( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) + hf_config = self._get_hf_config() placeholder = hf_config.audio_token_index feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") @@ -191,27 +234,6 @@ def _always_apply_prompt_replacements(self) -> bool: # tokens than the number of audio items) return True - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - feature_extractor = self._get_feature_extractor() - - sampling_rate = feature_extractor.sampling_rate - audio_len = feature_extractor.chunk_length * sampling_rate - num_audios = mm_counts.get("audio", 0) - - mm_data = { - "audio": - self._get_dummy_audios(length=audio_len, num_audios=num_audios) - } - - return ProcessorInputs( - prompt_text="<|AUDIO|>" * num_audios, - mm_data=mm_data, - ) - @MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor) class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index abca85e0e2024..a5c2fb9e84df3 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -59,8 +59,9 @@ from vllm.multimodal.parse import (ImageSize, ModalityDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, + MultiModalDataItems, ProcessingMixin, PromptReplacement) +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.config import uses_mrope @@ -708,10 +709,44 @@ def _parse_video_data( return super()._parse_video_data(data) -class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): +class Qwen2VLProcessingMixin(ProcessingMixin): - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - return {"image": None, "video": None} + def _get_hf_config(self): + return self.ctx.get_hf_config(Qwen2VLConfig) + + def _get_hf_processor( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + ) -> Qwen2VLProcessor: + hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor) + image_processor = hf_processor.image_processor # type: ignore + assert isinstance(image_processor, Qwen2VLImageProcessor) + + if min_pixels: + image_processor.min_pixels = min_pixels + if max_pixels: + image_processor.max_pixels = max_pixels + if max_pixels or min_pixels: + image_processor.size = { + "min_pixels": image_processor.min_pixels, + "max_pixels": image_processor.max_pixels, + } + + return hf_processor + + def _get_image_processor( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + ): + hf_processor = self._get_hf_processor(min_pixels=min_pixels, + max_pixels=max_pixels) + image_processor = hf_processor.image_processor # type: ignore + assert isinstance(image_processor, Qwen2VLImageProcessor) + return image_processor def _get_vision_info( self, @@ -721,14 +756,13 @@ def _get_vision_info( num_frames: int = 1, do_resize: bool = True, ) -> 
tuple[ImageSize, int]: - hf_config = self.ctx.get_hf_config(Qwen2VLConfig) + hf_config = self._get_hf_config() vision_config = hf_config.vision_config patch_size = vision_config.patch_size merge_size = vision_config.spatial_merge_size temporal_patch_size = vision_config.temporal_patch_size - hf_processor = self._get_hf_processor() - image_processor = self._get_image_processor(hf_processor) + image_processor = self._get_image_processor() if do_resize: resized_height, resized_width = smart_resize( @@ -753,7 +787,45 @@ def _get_vision_info( return preprocessed_size, num_vision_tokens - def _get_dummy_image_size(self) -> ImageSize: + def _get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + _, num_image_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + ) + return num_image_tokens + + def _get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + ) -> int: + _, num_video_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + num_frames=num_frames, + ) + return num_video_tokens + + +class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo): + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return { + "image": self._get_max_image_tokens(), + "video": self._get_max_video_tokens(seq_len), + } + + def _get_image_size_with_most_features(self) -> ImageSize: max_image_size, _ = self._get_vision_info( image_width=9999999, image_height=9999999, @@ -761,27 +833,27 @@ def _get_dummy_image_size(self) -> ImageSize: return max_image_size def _get_max_image_tokens(self) -> int: - _, max_image_tokens = self._get_vision_info( - image_width=9999999, - image_height=9999999, - ) - return max_image_tokens + target_width, target_height = self._get_image_size_with_most_features() - def _get_max_video_tokens(self, num_frames: int) -> int: - _, max_video_tokens = self._get_vision_info( - image_width=9999999, - image_height=9999999, - num_frames=num_frames, + return self._get_num_image_tokens( + image_width=target_width, + image_height=target_height, ) - return max_video_tokens def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self._get_image_size_with_most_features() + num_frames = 0 while True: next_num_frames = num_frames + 1 + next_max_tokens = self._get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + ) - if self._get_max_video_tokens(next_num_frames) > max_tokens: + if next_max_tokens > max_tokens: break num_frames = next_num_frames @@ -797,56 +869,73 @@ def _get_dummy_num_frames(self, seq_len: int) -> int: max_total_frames = self._get_max_video_frames(seq_len - max_image_tokens) - return max(max_total_frames // max(max_videos, 1), 1) + num_frames = max(max_total_frames // max(max_videos, 1), 1) - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - max_image_tokens = self._get_max_image_tokens() + # Temporary workaround for https://github.com/huggingface/transformers/issues/35412 + if num_frames > 1 and num_frames % 2 == 1: + num_frames += 1 - num_frames = self._get_dummy_num_frames(seq_len) - max_video_tokens = self._get_max_video_tokens(num_frames) + return num_frames - return { - "image": max_image_tokens, - "video": max_video_tokens, + def _get_max_video_tokens(self, seq_len: int) -> int: + 
target_width, target_height = self._get_image_size_with_most_features() + + return self._get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self._get_dummy_num_frames(seq_len), + ) + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + hf_processor = self._get_hf_processor() + image_token: str = hf_processor.image_token + video_token: str = hf_processor.video_token + target_width, target_height = self._get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=self._get_dummy_num_frames(seq_len), + num_videos=num_videos, + ) } - def _get_data_parser(self) -> MultiModalDataParser: - return Qwen2MultiModalDataParser() + return ProcessorInputs( + prompt_text=image_token * num_images + video_token * num_videos, + mm_data=mm_data, + ) - def _get_image_processor(self, hf_processor: Qwen2VLProcessor): - image_processor = hf_processor.image_processor # type: ignore - assert isinstance(image_processor, Qwen2VLImageProcessor) - return image_processor - def _get_hf_processor( - self, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, - ) -> Qwen2VLProcessor: - hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor) - image_processor = self._get_image_processor(hf_processor) +class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin, + BaseMultiModalProcessor): - if min_pixels: - image_processor.min_pixels = min_pixels - if max_pixels: - image_processor.max_pixels = max_pixels - if max_pixels or min_pixels: - image_processor.size = { - "min_pixels": image_processor.min_pixels, - "max_pixels": image_processor.max_pixels, - } + def _get_profiling_info(self) -> BaseProfilingInfo: + return Qwen2VLProfilingInfo(self.ctx) - return hf_processor + def _get_data_parser(self) -> MultiModalDataParser: + return Qwen2MultiModalDataParser() def _get_prompt_replacements( self, mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self._get_hf_processor() - image_processor = self._get_image_processor(hf_processor) + hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self._get_image_processor(**hf_processor_mm_kwargs) # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has # image_token and video_token registered @@ -901,38 +990,6 @@ def _get_mm_fields_config( video_grid_thw=MultiModalFieldConfig.batched("video"), ) - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - num_images = mm_counts.get("image", 0) - num_videos = mm_counts.get("video", 0) - - hf_processor = self._get_hf_processor() - image_token: str = hf_processor.image_token - video_token: str = hf_processor.video_token - target_width, target_height = self._get_dummy_image_size() - - mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images), - "video": - self._get_dummy_videos( - width=target_width, - height=target_height, - num_frames=self._get_dummy_num_frames(seq_len), - num_videos=num_videos, - ) - } - - return ProcessorInputs( - 
prompt_text=image_token * num_images + video_token * num_videos, - mm_data=mm_data, - ) - @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor) class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 6ad4661e3bb8d..ba823acecbb56 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,8 +3,8 @@ import math from functools import cached_property -from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, - TypedDict, Union) +from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, + Tuple, TypedDict, Union) import torch import torch.utils.checkpoint @@ -26,8 +26,9 @@ NestedTensors) from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.processing import (BaseMultiModalProcessor, - MultiModalDataItems, ProcessorInputs, + MultiModalDataItems, ProcessingMixin, PromptReplacement) +from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig @@ -55,7 +56,30 @@ class UltravoxAudioEmbeddingInputs(TypedDict): UltravoxAudioEmbeddingInputs] -class UltravoxMultiModalProcessor(BaseMultiModalProcessor): +class UltravoxProcessingMixin(ProcessingMixin): + + def _get_hf_processor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> ProcessorMixin: + return self.ctx.get_hf_processor() + + def _get_feature_extractor( + self, + *, + # Ignored in initialization + sampling_rate: Optional[int] = None, + ) -> WhisperFeatureExtractor: + hf_processor = self._get_hf_processor(sampling_rate=sampling_rate) + audio_processor = hf_processor.audio_processor # type: ignore + feature_extractor = audio_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, WhisperFeatureExtractor) + return feature_extractor + + +class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} @@ -67,17 +91,33 @@ def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"audio": max_audio_tokens} - def _get_hf_processor( + def get_dummy_processor_inputs( self, - *, - # Ignored in initialization - sampling_rate: Optional[int] = None, - ) -> ProcessorMixin: - return self.ctx.get_hf_processor() + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + feature_extractor = self._get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) + + mm_data = { + "audio": + self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + return ProcessorInputs( + prompt_text="<|audio|>" * num_audios, + mm_data=mm_data, + ) - def _get_feature_extractor(self) -> WhisperFeatureExtractor: - hf_processor = self._get_hf_processor() - return hf_processor.audio_processor.feature_extractor # type: ignore + +class UltravoxMultiModalProcessor(UltravoxProcessingMixin, + BaseMultiModalProcessor): + + def _get_profiling_info(self) -> BaseProfilingInfo: + return UltravoxProfilingInfo(self.ctx) def _get_data_parser(self) -> MultiModalDataParser: feature_extractor = self._get_feature_extractor() @@ -155,10 +195,10 @@ def _get_mm_fields_config( def _get_prompt_replacements( self, mm_items: 
MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, object], + hf_processor_mm_kwargs: Mapping[str, Any], out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: - hf_processor = self._get_hf_processor() + hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) placeholder = hf_processor.audio_token_replacement # type: ignore def get_replacement_ultravox(item_idx: int): @@ -173,27 +213,6 @@ def get_replacement_ultravox(item_idx: int): ) ] - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - feature_extractor = self._get_feature_extractor() - - sampling_rate = feature_extractor.sampling_rate - audio_len = feature_extractor.chunk_length * sampling_rate - num_audios = mm_counts.get("audio", 0) - - mm_data = { - "audio": - self._get_dummy_audios(length=audio_len, num_audios=num_audios) - } - - return ProcessorInputs( - prompt_text="<|audio|>" * num_audios, - mm_data=mm_data, - ) - class StackAudioFrames(nn.Module): """ diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 014f02ee10a1b..8516c9f7066f7 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -1,12 +1,8 @@ from abc import ABC, abstractmethod -from typing import Final, Generic, Optional, Protocol, TypeVar +from typing import Final, Generic, Protocol, TypeVar from transformers import PretrainedConfig -from vllm.multimodal.processing import (BaseMultiModalProcessor, - InputProcessingContext, - ProcessingCache) - _C = TypeVar("_C", bound=PretrainedConfig) @@ -43,12 +39,18 @@ def get_patch_grid_length(self) -> int: raise NotImplementedError -def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo: +class VisionLanguageConfig(Protocol): + vision_config: Final[PretrainedConfig] + + +def get_vision_encoder_info( + hf_config: VisionLanguageConfig) -> VisionEncoderInfo: # Avoid circular imports from .clip import CLIPEncoderInfo, CLIPVisionConfig from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig from .siglip import SiglipEncoderInfo, SiglipVisionConfig + vision_config = hf_config.vision_config if isinstance(vision_config, CLIPVisionConfig): return CLIPEncoderInfo(vision_config) if isinstance(vision_config, PixtralVisionConfig): @@ -58,26 +60,3 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo: msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) - - -class VisionLanguageConfig(Protocol): - vision_config: Final[PretrainedConfig] - - -class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor): - - def __init__(self, - ctx: InputProcessingContext, - *, - cache: Optional[ProcessingCache] = None, - enable_sanity_checks: bool = True) -> None: - super().__init__(ctx, - cache=cache, - enable_sanity_checks=enable_sanity_checks) - - vision_config = self._get_hf_config().vision_config - self._vision_encoder_info = vision_encoder_info(vision_config) - - @abstractmethod - def _get_hf_config(self) -> VisionLanguageConfig: - raise NotImplementedError diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index ebc16b817684a..933c1d3aff0cb 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -8,11 +8,10 @@ from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union import numpy as np -import numpy.typing as npt import torch from blake3 import blake3 from PIL import Image -from transformers import BatchFeature, 
ProcessorMixin +from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from vllm.inputs import DummyData, InputProcessingContext from vllm.logger import init_logger @@ -24,6 +23,7 @@ MultiModalInputsV2, MultiModalKwargs, MultiModalKwargsItem, PlaceholderRange) from .parse import MultiModalDataItems, MultiModalDataParser +from .profiling import BaseProfilingInfo logger = init_logger(__name__) @@ -466,14 +466,6 @@ def find_mm_placeholders( return dict(full_groupby_modality(it)) -@dataclass -class ProcessorInputs: - """Keyword arguments to :meth:`BaseMultiModalProcessor`.""" - prompt_text: str - mm_data: MultiModalDataDict - hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) - - class ProcessingCache: def __init__(self, capacity: int) -> None: @@ -585,9 +577,33 @@ def put( self._cache.put(cache_key, output_kwargs) -class BaseMultiModalProcessor(ABC): +class ProcessingMixin: + """ + Contains helper functions to perform processing. + + Not to be confused with :class:`transformers.ProcessorMixin`. + """ + ctx: InputProcessingContext + + def _get_tokenizer(self) -> AnyTokenizer: + return self.ctx.tokenizer + + def _get_hf_config(self) -> PretrainedConfig: + return self.ctx.get_hf_config() + + def _get_hf_processor(self, **kwargs: object) -> ProcessorMixin: + """ + Subclasses can override this method to handle + specific kwargs from model config or user inputs. + """ + return self.ctx.get_hf_processor(**kwargs) + + +class BaseMultiModalProcessor(ProcessingMixin, ABC): """ Abstract base class to process multi-modal inputs to be used in vLLM. + + Not to be confused with :class:`transformers.ProcessorMixin`. """ def __init__(self, @@ -601,6 +617,9 @@ def __init__(self, self.cache = cache self.enable_sanity_checks = enable_sanity_checks + self.data_parser = self._get_data_parser() + self.profiling_info = self._get_profiling_info() + def __call__( self, prompt: str, @@ -609,32 +628,9 @@ def __call__( ) -> MultiModalInputsV2: return self.apply(prompt, mm_data, hf_processor_mm_kwargs) - @abstractmethod - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: - """ - Return the maximum supported number of items for each modality. - - A value of `None` means unlimited number of items. - - Omitting a modality from the returned dictionary means that - it is not supported at all. - """ - raise NotImplementedError - - @abstractmethod - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - """ - Get the maximum possible number of tokens per data item - for each modality. - - The dictionary returned by this method should have the same - keys as that returned by :meth:`get_supported_mm_limits`. - """ - raise NotImplementedError - def _get_data_parser(self) -> MultiModalDataParser: """ - Construct a data parser to preprocess multi-modal data items + Construct a parser to preprocess multi-modal data items before passing them to :meth:`_get_hf_mm_data`. You can support additional modalities by creating a subclass @@ -642,15 +638,12 @@ def _get_data_parser(self) -> MultiModalDataParser: """ return MultiModalDataParser() - def _get_hf_processor(self) -> ProcessorMixin: + def _get_profiling_info(self) -> BaseProfilingInfo: """ - Subclasses can add keyword arguments to this method to accept - additional kwargs from model config or user inputs. + Get the profiling information to find the worst-case memory usage of + the model. 
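To make the intent of `ProcessingMixin` concrete: the helpers are written once in a mixin, and both the profiling class and the processor inherit them, which is the shape of the Qwen2-Audio, Qwen2-VL, and Ultravox hunks above. A toy sketch with stand-in classes (none of these names are real vLLM types):

```python
# Toy stand-ins, not the real vLLM classes: one mixin shares helpers between
# a ProfilingInfo-style class and a processor-style class.
class FakeContext:
    num_query_tokens = 32


class FakeProcessingMixin:
    ctx: FakeContext

    def _get_num_image_tokens(self) -> int:
        return self.ctx.num_query_tokens


class FakeProfilingInfo(FakeProcessingMixin):
    def __init__(self, ctx: FakeContext) -> None:
        self.ctx = ctx

    def get_mm_max_tokens_per_item(self, seq_len: int) -> dict[str, int]:
        return {"image": self._get_num_image_tokens()}


class FakeMultiModalProcessor(FakeProcessingMixin):
    def __init__(self, ctx: FakeContext) -> None:
        self.ctx = ctx
        # Mirrors BaseMultiModalProcessor.__init__ wiring self.profiling_info.
        self.profiling_info = FakeProfilingInfo(ctx)


processor = FakeMultiModalProcessor(FakeContext())
print(processor.profiling_info.get_mm_max_tokens_per_item(seq_len=4096))
# -> {'image': 32}
```

Keeping `profiling_info` as an attribute built in `__init__`, rather than constructing it per call, matches how this diff also caches `self.data_parser`.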
""" - return self.ctx.get_hf_processor() - - def _get_tokenizer(self) -> AnyTokenizer: - return self.ctx.tokenizer + raise NotImplementedError def _to_mm_items( self, @@ -660,8 +653,7 @@ def _to_mm_items( Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems` before passing them to :meth:`_get_hf_mm_data`. """ - parser = self._get_data_parser() - mm_items = parser.parse_mm_data(mm_data) + mm_items = self.data_parser.parse_mm_data(mm_data) mm_limits = self.ctx.get_mm_config().limit_per_prompt for modality, items in mm_items.items(): @@ -799,7 +791,7 @@ def _apply_hf_processor_missing( # Some HF processors (e.g. Qwen2-VL) expect corresponding # multi-modal tokens to be in the prompt text - dummy_inputs = self._get_dummy_processor_inputs( + dummy_inputs = self.profiling_info.get_dummy_processor_inputs( self.ctx.model_config.max_model_len, mm_missing_counts, ) @@ -1133,73 +1125,14 @@ def apply( mm_placeholders=mm_placeholder_ranges, ) - def _get_dummy_audios( - self, - *, - length: int, - num_audios: int, - ) -> list[npt.NDArray]: - audio = np.zeros((length, )) - return [audio] * num_audios - - def _get_dummy_images( - self, - *, - width: int, - height: int, - num_images: int, - ) -> list[Image.Image]: - image = Image.new("RGB", (width, height), color=0) - return [image] * num_images - - def _get_dummy_videos( - self, - *, - width: int, - height: int, - num_frames: int, - num_videos: int, - ) -> list[npt.NDArray]: - video = np.zeros((num_frames, width, height, 3)) - return [video] * num_videos - - @abstractmethod - def _get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> ProcessorInputs: - """ - Build the multi-modal portion of the input which, after processing, - results in `mm_max_tokens` in :meth:`get_dummy_data`. 
- """ - raise NotImplementedError - - def _get_and_validate_dummy_mm_counts(self) -> Mapping[str, int]: - mm_limit_per_prompt = self.ctx.get_mm_config().limit_per_prompt - supported_mm_limits = self.get_supported_mm_limits() - - mm_limits = { - modality: mm_limit_per_prompt.get(modality, 1) - for modality in supported_mm_limits - } - - for modality, supported_limit in supported_mm_limits.items(): - limit = mm_limits[modality] - if supported_limit is not None and supported_limit < limit: - raise ValueError( - f"You set {modality}={limit} (or defaulted to 1) in " - f"`--limit-mm-per-prompt`, but this model only supports " - f"at most {supported_limit} {modality} items.") - - return mm_limits - def _get_dummy_mm_inputs( self, seq_len: int, mm_counts: Mapping[str, int], ) -> MultiModalInputsV2: - processor_inputs = self._get_dummy_processor_inputs(seq_len, mm_counts) + profiling = self.profiling_info + processor_inputs = profiling.get_dummy_processor_inputs( + seq_len, mm_counts) return self.apply( prompt_text=processor_inputs.prompt_text, @@ -1211,8 +1144,9 @@ def get_dummy_data(self, seq_len: int) -> DummyData: # Avoid circular import from vllm.sequence import SequenceData - mm_counts = self._get_and_validate_dummy_mm_counts() - mm_max_tokens_per_item = self.get_mm_max_tokens_per_item(seq_len) + profiling = self.profiling_info + mm_counts = profiling.get_mm_limits() + mm_max_tokens_per_item = profiling.get_mm_max_tokens_per_item(seq_len) if mm_counts.keys() != mm_max_tokens_per_item.keys(): raise AssertionError( "The keys returned by `get_supported_mm_limits`" diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py new file mode 100644 index 0000000000000..2ecf0db1a485d --- /dev/null +++ b/vllm/multimodal/profiling.py @@ -0,0 +1,121 @@ +from abc import ABC, abstractmethod +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import numpy.typing as npt +from PIL import Image + +from vllm.inputs import InputProcessingContext +from vllm.logger import init_logger + +from .inputs import MultiModalDataDict + +logger = init_logger(__name__) + + +@dataclass +class ProcessorInputs: + """Keyword arguments to :meth:`BaseMultiModalProcessor`.""" + prompt_text: str + mm_data: MultiModalDataDict + hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) + + +class BaseProfilingInfo(ABC): + """ + Abstract base class that provides the information necessary to profile + multi-modal models. + """ + + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__() + + self.ctx = ctx + + @abstractmethod + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + """ + Return the maximum supported number of items for each modality. + + A value of `None` means unlimited number of items. + + Omitting a modality from the returned dictionary means that + it is not supported at all. + """ + raise NotImplementedError + + @abstractmethod + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + """ + Get the maximum possible number of tokens per data item + for each modality. + + The dictionary returned by this method should have the same + keys as that returned by :meth:`get_supported_mm_limits`. 
+ """ + raise NotImplementedError + + @abstractmethod + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + """ + Build the multi-modal portion of the input which, after processing, + results in `mm_max_tokens` in :meth:`get_mm_max_tokens_per_item`. + """ + raise NotImplementedError + + def _get_dummy_audios( + self, + *, + length: int, + num_audios: int, + ) -> list[npt.NDArray]: + audio = np.zeros((length, )) + return [audio] * num_audios + + def _get_dummy_images( + self, + *, + width: int, + height: int, + num_images: int, + ) -> list[Image.Image]: + image = Image.new("RGB", (width, height), color=0) + return [image] * num_images + + def _get_dummy_videos( + self, + *, + width: int, + height: int, + num_frames: int, + num_videos: int, + ) -> list[npt.NDArray]: + video = np.zeros((num_frames, width, height, 3)) + return [video] * num_videos + + def get_mm_limits(self) -> Mapping[str, int]: + mm_config = self.ctx.get_mm_config() + mm_limit_per_prompt = mm_config.limit_per_prompt + + supported_mm_limits = self.get_supported_mm_limits() + + mm_limits = { + modality: mm_limit_per_prompt.get(modality, 1) + for modality in supported_mm_limits + } + + for modality, supported_limit in supported_mm_limits.items(): + limit = mm_limits[modality] + if supported_limit is not None and supported_limit < limit: + raise ValueError( + f"You set {modality}={limit} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but this model only supports " + f"at most {supported_limit} {modality} items.") + + return mm_limits diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index fb4389dc4df42..f75a594a4c4e0 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -224,7 +224,7 @@ def get_max_tokens_per_item_by_modality( tokenizer = cached_get_tokenizer(model_config.tokenizer) processor = self.create_processor(model_config, tokenizer) seq_len = model_config.max_model_len - return processor.get_mm_max_tokens_per_item(seq_len) + return processor.profiling_info.get_mm_max_tokens_per_item(seq_len) return { key: plugin.get_max_multimodal_tokens(model_config)