From 83b824c8b4ee55824b30f0509fd312b0cddb35e5 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 11 Apr 2025 00:06:58 +0800 Subject: [PATCH] [VLM] Remove `BaseProcessingInfo.get_mm_max_tokens_per_item` (#16408) Signed-off-by: DarkLight1337 --- docs/source/contributing/model/multimodal.md | 240 +++--------------- .../multimodal/processing/test_llama4.py | 5 - vllm/model_executor/models/aria.py | 7 - vllm/model_executor/models/aya_vision.py | 25 -- vllm/model_executor/models/blip2.py | 7 - vllm/model_executor/models/chameleon.py | 7 - vllm/model_executor/models/clip.py | 3 - vllm/model_executor/models/deepseek_vl2.py | 14 - vllm/model_executor/models/florence2.py | 11 +- vllm/model_executor/models/fuyu.py | 15 -- vllm/model_executor/models/gemma3_mm.py | 16 -- vllm/model_executor/models/glm4v.py | 7 - vllm/model_executor/models/h2ovl.py | 23 -- vllm/model_executor/models/idefics3.py | 16 -- vllm/model_executor/models/internvl.py | 16 -- vllm/model_executor/models/llava.py | 7 - .../model_executor/models/llava_next_video.py | 16 -- vllm/model_executor/models/llava_onevision.py | 10 - vllm/model_executor/models/minicpmo.py | 11 - vllm/model_executor/models/minicpmv.py | 12 - vllm/model_executor/models/mistral3.py | 15 -- vllm/model_executor/models/mllama.py | 10 - vllm/model_executor/models/mllama4.py | 19 -- vllm/model_executor/models/molmo.py | 16 -- vllm/model_executor/models/paligemma.py | 48 ++-- vllm/model_executor/models/phi3v.py | 15 -- vllm/model_executor/models/pixtral.py | 23 -- .../models/prithvi_geospatial_mae.py | 3 - vllm/model_executor/models/qwen2_audio.py | 11 - vllm/model_executor/models/qwen2_vl.py | 10 - vllm/model_executor/models/qwen_vl.py | 7 - vllm/model_executor/models/siglip.py | 3 - vllm/model_executor/models/skyworkr1v.py | 16 -- vllm/model_executor/models/ultravox.py | 12 - vllm/model_executor/models/vision.py | 4 - vllm/model_executor/models/whisper.py | 11 +- vllm/multimodal/processing.py | 15 -- vllm/multimodal/profiling.py | 65 ++--- vllm/multimodal/registry.py | 10 +- 39 files changed, 104 insertions(+), 677 deletions(-) diff --git a/docs/source/contributing/model/multimodal.md b/docs/source/contributing/model/multimodal.md index 0c749633..03d830fe 100644 --- a/docs/source/contributing/model/multimodal.md +++ b/docs/source/contributing/model/multimodal.md @@ -121,17 +121,21 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": 1} ``` -### Maximum number of placeholder feature tokens +## 3. Specify dummy inputs -Also, override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item` -to return the maximum number of placeholder feature tokens per input item for each modality. +Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for +HF processing as well as memory profiling. -When calling the model, the output embeddings from the visual encoder are assigned to the input positions -containing placeholder feature tokens. Therefore, the number of placeholder feature tokens should be equal -to the size of the output embeddings. +### For memory profiling -:::::{tab-set} -::::{tab-item} Basic example: LLaVA +Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs` +to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of +the model so that vLLM can reserve the correct amount of memory for it. 
+ +Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed to maximize the number of output embeddings, which is the same number as placeholder feature tokens. + +::::{tab-set} +:::{tab-item} Basic example: LLaVA :sync: llava Looking at the code of HF's `LlavaForConditionalGeneration`: @@ -240,7 +244,7 @@ def get_num_image_tokens( ``` Notice that the number of image tokens doesn't depend on the image width and height. -So, we can calculate the maximum number of image tokens using any image size: +We can simply use a dummy `image_size`: ```python def get_image_size_with_most_features(self) -> ImageSize: @@ -248,33 +252,35 @@ def get_image_size_with_most_features(self) -> ImageSize: width = height = hf_config.image_size return ImageSize(width=width, height=height) -def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) -``` - -And thus, we can override the method as: - -```python -def get_mm_max_tokens_per_item( +def get_dummy_processor_inputs( self, seq_len: int, mm_counts: Mapping[str, int], -) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} +) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + image_token = processor.image_token + + hf_config = self.get_hf_config() + target_width, target_height = self.info.get_image_size_with_most_features() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + return ProcessorInputs( + prompt_text=image_token * num_images, + mm_data=mm_data, + ) ``` -:::{note} -Our [actual code](gh-file:vllm/model_executor/models/llava.py) is more abstracted to support vision encoders other than CLIP. ::: -:::: - -::::{tab-item} Non-consecutive feature tokens: Fuyu +:::{tab-item} No input placeholders: Fuyu :sync: fuyu Looking at the code of HF's `FuyuForCausalLM`: @@ -394,188 +400,16 @@ num_patches_per_dim_w = image_width // patch_width num_patches = num_patches_per_dim_h * num_patches_per_dim_w ``` -We can calculate this in vLLM using this code: - -```python -def get_num_image_patches( - self, - *, - image_width: int, - image_height: int, -) -> int: - image_processor = self.get_image_processor() - target_width = image_processor.size["width"] - target_height = image_processor.size["height"] - patch_width = image_processor.patch_size["width"] - patch_height = image_processor.patch_size["height"] - - if not (image_width <= target_width and image_height <= target_height): - height_scale_factor = target_height / image_height - width_scale_factor = target_width / image_width - optimal_scale_factor = min(height_scale_factor, width_scale_factor) - - image_height = int(image_height * optimal_scale_factor) - image_width = int(image_width * optimal_scale_factor) - - ncols = math.ceil(image_width / patch_width) - nrows = math.ceil(image_height / patch_height) - return ncols * nrows -``` - -These image patches correspond to placeholder tokens (`|SPEAKER|`). 
However, the processor also -inserts newline tokens (`|NEWLINE|`) as shown here: - -```python -# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L654-L670 -tensor_of_image_ids = torch.full( - [num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device -) -patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0) -assert num_patches == patches.shape[0] - -if variable_sized: - # Now terminate each line with |NEWLINE|. - tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width) - newline_ids = torch.full( - [tensor_of_image_ids.shape[0], 1], - image_newline_id, - dtype=torch.int32, - device=image_input.device, - ) - tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1) - tensor_of_image_ids = tensor_of_image_ids.reshape(-1) -``` - -So, the layout of tokens for an image is: - -``` -|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| -|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| -... -|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE| -``` - -This makes the placeholder tokens non-consecutive in the prompt. -Since vLLM requires the feature tokens to be consecutive, **we also treat the newline tokens as feature tokens**. - -So overall, the total number of feature tokens is - -```python -def get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, -) -> int: - image_processor = self.get_image_processor() - target_width = image_processor.size["width"] - target_height = image_processor.size["height"] - patch_width = image_processor.patch_size["width"] - patch_height = image_processor.patch_size["height"] - - if not (image_width <= target_width and image_height <= target_height): - height_scale_factor = target_height / image_height - width_scale_factor = target_width / image_width - optimal_scale_factor = min(height_scale_factor, width_scale_factor) - - image_height = int(image_height * optimal_scale_factor) - image_width = int(image_width * optimal_scale_factor) - - ncols = math.ceil(image_width / patch_width) - nrows = math.ceil(image_height / patch_height) - return (ncols + 1) * nrows -``` - -To calculate the maximum number of image tokens, recall that input images are first resized -to fit within `image_processor.size`. The maximum possible dimensions of the image before -being converted into patches is therefore equal to `image_processor.size`. +These image patches correspond to placeholder tokens (`|SPEAKER|`). So, we just need to maximize the number of image patches. Since input images are first resized +to fit within `image_processor.size`, we can maximize the number of image patches by inputting an image with size equal to `image_processor.size`. ```python def get_image_size_with_most_features(self) -> ImageSize: image_processor = self.get_image_processor() return ImageSize(width=image_processor.size["width"], height=image_processor.size["height"]) - -def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) ``` -And thus, we can override the method as: - -```python -def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], -) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} -``` - -:::{note} -Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) returns `ncols` and `nrows` directly instead of the total token count. 
-This is because `ncols` and `nrows` are used to specify the layout of the feature tokens (as shown in Step 4 of this guide). -::: - -:::: -::::: - -## 3. Specify dummy inputs - -Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for -HF processing as well as memory profiling. - -### For memory profiling - -Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs` -to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of -the model so that vLLM can reserve the correct amount of memory for it. - -Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed based -on the code for {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item`. - -::::{tab-set} -:::{tab-item} Basic example: LLaVA -:sync: llava - -Making use of the `get_image_size_with_most_features` method implemented in Step 2: - -```python -def get_dummy_processor_inputs( - self, - seq_len: int, - mm_counts: Mapping[str, int], -) -> ProcessorInputs: - num_images = mm_counts.get("image", 0) - - processor = self.info.get_hf_processor() - image_token = processor.image_token - - hf_config = self.get_hf_config() - target_width, target_height = self.info.get_image_size_with_most_features() - - mm_data = { - "image": - self._get_dummy_images(width=target_width, - height=target_height, - num_images=num_images) - } - - return ProcessorInputs( - prompt_text=image_token * num_images, - mm_data=mm_data, - ) -``` - -::: - -:::{tab-item} No input placeholders: Fuyu -:sync: fuyu - Fuyu does not expect image placeholders in the inputs to HF processor, so the dummy prompt text is empty regardless of the number of images. 
Otherwise, the logic of this method is very similar to LLaVA: diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index 578dcd4a..2bfc2785 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -76,11 +76,6 @@ def test_processor_override( if v == config.boi_token_index] # patch sizes and masks - patch_token_id = vocab[hf_processor.img_patch_token] - num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id) - mm_counts = {"image": num_imgs} - assert num_patches / num_imgs <= \ - processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"] num_patches_per_chunk = processor.info.get_patch_per_chunk( config.vision_config) assert prompt_token_ids.count(config.image_token_index) \ diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index af340fef..23b8ef89 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -408,13 +408,6 @@ class AriaProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - def get_num_image_tokens(self) -> int: hf_config = self.get_hf_config() return max(hf_config.projector_patch_to_query_dict.values()) diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index 929c8f2a..cdec3160 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -117,31 +117,6 @@ class AyaVisionProcessingInfo(BaseProcessingInfo): def get_image_processor(self) -> GotOcr2ImageProcessor: return self.get_hf_processor().image_processor - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - - def get_max_image_tokens(self) -> int: - hf_processor = self.get_hf_processor() - image_processor = hf_processor.image_processor - - image_size = self.get_image_size_with_most_features() - num_patches = self.get_num_patches( - image_width=image_size.width, - image_height=image_size.height, - size=image_processor.size, - min_patches=image_processor.min_patches, - max_patches=image_processor.max_patches, - ) - - img_patches_per_tile = (hf_processor.img_size // - hf_processor.patch_size)**2 - - return num_patches * img_patches_per_tile - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index a1f20ea4..dde78ee5 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -406,13 +406,6 @@ class Blip2ProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - def get_num_image_tokens(self) -> int: hf_config = self.get_hf_config() return hf_config.num_query_tokens diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index d46ae532..fb2f4b67 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -64,13 +64,6 @@ 
class ChameleonProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - def get_num_image_tokens(self) -> int: processor = self.get_hf_processor() return processor.image_seq_length diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index dc3aa9cb..153054e5 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -30,9 +30,6 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): ) -> int: return self.get_patch_grid_length()**2 + 1 - def get_max_image_tokens(self) -> int: - return self.get_patch_grid_length()**2 + 1 - def get_image_size(self) -> int: return self.vision_config.image_size diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 03d5be29..951185bc 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -168,20 +168,6 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo): image_width=x[1], image_height=x[0])) return ImageSize(width=width, height=height) - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - num_images = mm_counts.get("image", 0) - max_image_size = self.get_image_size_with_most_features() - max_image_tokens = self.get_num_image_tokens( - image_height=max_image_size.height, - image_width=max_image_size.width, - cropping=num_images <= 2) - - return {"image": max_image_tokens} - class DeepseekVL2DummyInputsBuilder( BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]): diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 62fd0939..56572bd5 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -764,17 +764,10 @@ class Florence2ProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_max_image_tokens(self) -> int: + def get_num_image_tokens(self) -> int: processor_config = self.ctx.get_hf_image_processor_config() return processor_config["image_seq_length"] - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - class Florence2DummyInputsBuilder( BaseDummyInputsBuilder[Florence2ProcessingInfo]): @@ -871,7 +864,7 @@ class Florence2MultiModalProcessor( ) -> Sequence[PromptUpdate]: hf_config = self.info.get_hf_config() pad_token_id = hf_config.pad_token_id - num_image_tokens = self.info.get_max_image_tokens() + num_image_tokens = self.info.get_num_image_tokens() image_tokens = [pad_token_id] * num_image_tokens return [ diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index c0a0f572..5fc6bb84 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -80,13 +80,6 @@ class FuyuProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_image_feature_grid_size( self, *, @@ -129,14 +122,6 @@ class FuyuProcessingInfo(BaseProcessingInfo): 
return ImageSize(width=image_processor.size["width"], height=image_processor.size["height"]) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) - class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 93d0aa30..34d856f4 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -68,13 +68,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def _resolve_image_kwargs( self, processor: Gemma3Processor, @@ -228,15 +221,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): # Result in the max possible feature size (h:w = max_num_crops:1) return ImageSize(height=50 * max_num_crops, width=50) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 6d7b760d..02954eec 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -431,13 +431,6 @@ class GLM4VProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_feature_tokens()} - def get_num_image_tokens(self) -> int: hf_config = self.get_hf_config() vision_config = hf_config.vision_config diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index f975a19a..15e126b0 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -412,19 +412,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): **kwargs, ) - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - max_tokens_one_image = self.get_max_image_tokens(use_msac=None) - if mm_counts.get("image", 0) <= 1: - max_tokens_per_image = max_tokens_one_image - else: - max_tokens_per_image = self.get_max_image_tokens(use_msac=False) - - return {"image": max_tokens_per_image} - def get_num_image_tokens( self, *, @@ -442,16 +429,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): use_msac=use_msac, ) - def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - use_msac=use_msac, - ) - class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] ): diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index ec02d1c8..655db1c8 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -97,13 +97,6 @@ class 
Idefics3ProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def _resize_output_size(self, *, height: int, @@ -287,15 +280,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): height=image_processor.size["longest_edge"], ) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] ): diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 7fd628fa..08741b3a 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -458,13 +458,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_num_image_tokens( self, *, @@ -480,15 +473,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): image_height=image_height, ) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 95165500..5804cb44 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -137,13 +137,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def _apply_feature_select_strategy( self, strategy: str, diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 6fc4c187..281c9c0e 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -61,22 +61,6 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"video": 1} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - target_width, target_height = self.get_image_size_with_most_features() - - max_video_tokens = self.get_num_video_tokens( - image_width=target_width, - image_height=target_height, - num_frames=self.get_num_frames_with_most_features( - seq_len, mm_counts), - ) - - return {"video": max_video_tokens} - def get_image_size_with_most_features(self) -> ImageSize: vision_encoder_info = self.get_vision_encoder_info() width = height = vision_encoder_info.get_image_size() diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 5fbd27b9..f6256771 100644 --- 
a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -101,16 +101,6 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return { - "image": self.get_max_image_tokens(), - "video": self.get_max_video_tokens(seq_len, mm_counts), - } - # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 # with additional logic afterwards taken from LlavaOnevisionProcessor def _get_num_unpadded_features( diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index a4fb0cb1..8bb41a10 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -142,17 +142,6 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {**super().get_supported_mm_limits(), "audio": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return { - **super().get_mm_max_tokens_per_item(seq_len, mm_counts), - "audio": - self.get_max_audio_tokens(), - } - def get_audio_placeholder( self, audio_lens: int, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 12b5364c..87c69021 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -346,18 +346,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): return mm_limits - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - mm_max_tokens = {"image": self.get_max_image_tokens()} - if self.get_model_version() == (2, 6): - mm_max_tokens["video"] = self.get_max_video_tokens( - seq_len, mm_counts) - - return mm_max_tokens - def get_slice_image_placeholder( self, image_size: ImageSize, diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 67c0e2ec..d2c600fe 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -162,13 +162,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_num_image_tokens( self, *, @@ -186,14 +179,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo): width = height = vision_encoder_info.get_image_size() return ImageSize(width=width, height=height) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) - _I = TypeVar("_I", bound=BaseLlavaProcessingInfo) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index d332b17f..b61e42f3 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -106,16 +106,6 @@ class MllamaProcessingInfo(BaseProcessingInfo): image_size = self.get_hf_config().vision_config.image_size return calc_token_per_chunk(image_size) - def 
get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - vision_config = self.get_hf_config().vision_config - token_per_chunk = self.get_token_per_chunk_from_config() - mm_max_tokens = vision_config.max_num_tiles * token_per_chunk - return {"image": mm_max_tokens} - def get_num_tiles_per_image(self, image_height: int, image_width: int) -> int: vision_config = self.get_hf_config().vision_config diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 17171f82..4f709751 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -498,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): image_processor = self.get_hf_processor().image_processor return image_processor.max_patches - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - vision_config = self.get_hf_config().vision_config - patch_per_chunk = self.get_patch_per_chunk(vision_config) - num_patches = self.get_max_num_tiles() + 1 - - return {"image": patch_per_chunk * num_patches} - def get_image_size_with_most_features(self) -> ImageSize: vision_config = self.get_hf_config().vision_config image_size = vision_config.image_size @@ -516,14 +505,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): return ImageSize(height=self.get_max_num_tiles() * image_size, width=image_size) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) - class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ): diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index a7551e61..d896431b 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1164,13 +1164,6 @@ class MolmoProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_num_image_tokens( self, *, @@ -1195,15 +1188,6 @@ class MolmoProcessingInfo(BaseProcessingInfo): return extra + joint - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 274163ac..ae8eee45 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -13,7 +13,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs) -from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, + MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptIndexTargets, PromptInsertion, PromptUpdate, @@ -72,16 +73,18 @@ class 
PaliGemmaProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": 1} - def get_mm_max_tokens_per_item( + def get_num_image_tokens( self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - - def get_num_image_tokens(self) -> int: + *, + image_width: int, + image_height: int, + ) -> int: vision_encoder_info = self.get_vision_encoder_info() - return vision_encoder_info.get_max_image_tokens() + + return vision_encoder_info.get_num_image_tokens( + image_width=image_width, + image_height=image_height, + ) class PaliGemmaDummyInputsBuilder( @@ -148,12 +151,30 @@ class PaliGemmaMultiModalProcessor( image_token_id = hf_config.image_token_index tokenizer = self.info.get_tokenizer() - num_image_tokens = self.info.get_num_image_tokens() - image_tokens = [image_token_id] * num_image_tokens bos_token_id = tokenizer.bos_token_id assert isinstance(bos_token_id, int) + def get_insertion(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems)) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + image_tokens = [image_token_id] * num_image_tokens + + return PromptUpdateDetails.select_token_id( + image_tokens + [bos_token_id], + embed_token_id=image_token_id, + ) + # Paligemma 1 and 2 have different tokenizer.add_bos_token # Insert *n + after for Paligemma 1 # Insert *n + for Paligemma 2 @@ -162,10 +183,7 @@ class PaliGemmaMultiModalProcessor( modality="image", target=PromptIndexTargets.prefix( [bos_token_id] if tokenizer.add_bos_token else []), - insertion=PromptUpdateDetails.select_token_id( - image_tokens + [bos_token_id], - embed_token_id=image_token_id, - ), + insertion=get_insertion, ) ] diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 344f348c..cce700f0 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -321,21 +321,6 @@ class Phi3VProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - target_width, target_height = self.get_image_size_with_most_features() - - max_image_tokens = self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - - return {"image": max_image_tokens} - def get_num_image_tokens( self, *, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 328d5271..fdd342cc 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -167,13 +167,6 @@ class PixtralProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_vision_config( self, processor: Optional[PixtralProcessorAdapter] = None, @@ -207,14 +200,6 @@ class PixtralProcessingInfo(BaseProcessingInfo): return ImageSize(width=max_image_size, height=max_image_size) - def get_max_image_tokens(self) -> 
int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - ) - class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): @@ -938,14 +923,6 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): ) return ncols * nrows - def get_max_image_tokens(self) -> int: - image_size = self.get_image_size() - - return self.get_num_image_tokens( - image_width=image_size, - image_height=image_size, - ) - def get_image_size(self) -> int: return self.vision_config.image_size diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index a69c0fc5..e3a93e95 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -45,9 +45,6 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: - return {"image": 0} - class PrithviGeoSpatialMAEInputBuilder( BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]): diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 9f2593fc..ba4646f5 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -109,17 +109,6 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - hf_config = self.get_hf_config() - max_source_positions = hf_config.audio_config.max_source_positions - max_output_lengths = (max_source_positions - 2) // 2 + 1 - - return {"audio": max_output_lengths} - class Qwen2AudioDummyInputsBuilder( BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]): diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index f93654d0..23f27e7e 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -818,16 +818,6 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return { - "image": self.get_max_image_tokens(), - "video": self.get_max_video_tokens(seq_len, mm_counts), - } - def _get_vision_info( self, *, diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 2e941f3b..403d47a3 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -530,13 +530,6 @@ class QwenVLProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_num_image_tokens()} - def get_num_image_tokens(self) -> int: hf_config = self.get_hf_config() vision_config = hf_config.visual diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index cecad9e8..75fcf540 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py 
@@ -33,9 +33,6 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): ) -> int: return self.get_patch_grid_length()**2 - def get_max_image_tokens(self) -> int: - return self.get_patch_grid_length()**2 - def get_image_size(self) -> int: return self.vision_config.image_size diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index a8460a2e..09a212a9 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -459,13 +459,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"image": self.get_max_image_tokens()} - def get_num_image_tokens( self, *, @@ -481,15 +474,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): image_height=image_height, ) - def get_max_image_tokens(self) -> int: - target_width, target_height = self.get_image_size_with_most_features() - - return self.get_num_image_tokens( - image_width=target_width, - image_height=target_height, - processor=None, - ) - def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 6e9d1526..3ff5a051 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -2,7 +2,6 @@ # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py """PyTorch Ultravox model.""" -import math from collections.abc import Iterable, Mapping, Sequence from functools import cached_property from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union @@ -107,17 +106,6 @@ class UltravoxProcessingInfo(BaseProcessingInfo): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - feature_extractor = self.get_feature_extractor() - max_audio_tokens = math.ceil(feature_extractor.chunk_length * - _AUDIO_TOKENS_PER_SECOND) - - return {"audio": max_audio_tokens * _MAX_ENCODER_BATCH_SIZE} - class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] ): diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 347f5149..05e3b3f3 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -33,10 +33,6 @@ class VisionEncoderInfo(ABC, Generic[_C]): ) -> int: raise NotImplementedError - @abstractmethod - def get_max_image_tokens(self) -> int: - raise NotImplementedError - @abstractmethod def get_image_size(self) -> int: raise NotImplementedError diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 7751f96d..341e22a4 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -538,16 +538,9 @@ class WhisperProcessingInfo(BaseProcessingInfo): assert isinstance(feature_extractor, WhisperFeatureExtractor) return feature_extractor - def get_max_audio_tokens(self) -> int: + def get_num_audio_tokens(self) -> int: return self.get_hf_config().max_source_positions - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - return {"audio": 
self.get_max_audio_tokens()} - class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): @@ -630,7 +623,7 @@ class WhisperMultiModalProcessor( hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargs, ) -> Sequence[PromptUpdate]: - num_tokens = self.info.get_max_audio_tokens() + num_tokens = self.info.get_num_audio_tokens() return [ PromptReplacement( modality="audio", diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 64f657db..fefeefd2 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1034,21 +1034,6 @@ class BaseProcessingInfo: """ raise NotImplementedError - @abstractmethod - def get_mm_max_tokens_per_item( - self, - seq_len: int, - mm_counts: Mapping[str, int], - ) -> Mapping[str, int]: - """ - Get the maximum possible number of tokens per data item - for each modality. - - The dictionary returned by this method should have the same - keys as that returned by :meth:`get_supported_mm_limits`. - """ - raise NotImplementedError - _I = TypeVar("_I", bound=BaseProcessingInfo) diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index ec3625f2..7efe8644 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -68,7 +68,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ) -> ProcessorInputs: """ Build the input which, after processing, results in - :code:`self.info.get_mm_max_tokens_per_item()` placeholder tokens. + the maximum possible number of placeholder tokens. """ raise NotImplementedError @@ -152,8 +152,11 @@ class MultiModalProfiler(Generic[_I]): def _get_dummy_mm_inputs( self, seq_len: int, - mm_counts: Mapping[str, int], + mm_counts: Optional[Mapping[str, int]] = None, ) -> MultiModalInputs: + if mm_counts is None: + mm_counts = self.get_mm_limits() + factory = self.dummy_inputs processor_inputs = factory.get_dummy_processor_inputs( seq_len, mm_counts) @@ -164,53 +167,23 @@ class MultiModalProfiler(Generic[_I]): hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, ) - def get_and_validate_mm_inputs( + def _get_mm_num_tokens( self, - seq_len: int, - mm_counts: Optional[Mapping[str, int]] = None, - ) -> tuple[MultiModalInputs, Mapping[str, int]]: - if mm_counts is None: - mm_counts = self.get_mm_limits() - - info = self.processing_info - mm_max_tokens_per_item = info.get_mm_max_tokens_per_item( - seq_len, mm_counts) - - if mm_counts.keys() - mm_max_tokens_per_item.keys(): - raise AssertionError( - "The keys returned by `get_supported_mm_limits` " - f"({set(mm_counts.keys())}) should be a subset of those " - "returned by `get_mm_max_tokens_per_item` " - f"({set(mm_max_tokens_per_item.keys())})") - - mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) + mm_inputs: MultiModalInputs, + ) -> Mapping[str, int]: placeholders_by_modality = mm_inputs["mm_placeholders"] - total_placeholders_by_modality = { + return { modality: sum(item.get_num_embeds() for item in placeholders) for modality, placeholders in placeholders_by_modality.items() } - expected_placeholders_by_modality = { - modality: mm_max_tokens_per_item[modality] * mm_counts[modality] - for modality in placeholders_by_modality - } - if total_placeholders_by_modality != expected_placeholders_by_modality: - raise AssertionError( - f"The processed dummy data has a total of " - f"{total_placeholders_by_modality} placeholder tokens, which " - f"is not the expected {expected_placeholders_by_modality} " - "tokens.") - return mm_inputs, total_placeholders_by_modality def 
get_encoder_dummy_data( self, seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, ) -> DummyEncoderData: - ( - mm_inputs, - total_placeholders_by_modality, - ) = self.get_and_validate_mm_inputs(seq_len, mm_counts) + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = cast(MultiModalEncDecInputs, mm_inputs) # For encoder-decoder models, use encoder prompt token ids instead of @@ -232,7 +205,7 @@ class MultiModalProfiler(Generic[_I]): " is too short " "to hold the multi-modal embeddings in the worst case " f"({total_len} tokens in total, out of which " - f"{total_placeholders_by_modality} are reserved for " + f"{self._get_mm_num_tokens(mm_inputs)} are reserved for " "multi-modal embeddings). This may cause certain " "multi-modal inputs to fail during inference, even when " "the input text is short. To avoid this, you should " @@ -246,10 +219,7 @@ class MultiModalProfiler(Generic[_I]): seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, ) -> DummyDecoderData: - ( - mm_inputs, - total_placeholders_by_modality, - ) = self.get_and_validate_mm_inputs(seq_len, mm_counts) + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) prompt_token_ids = mm_inputs["prompt_token_ids"] total_len = len(prompt_token_ids) @@ -263,7 +233,7 @@ class MultiModalProfiler(Generic[_I]): "is too short " "to hold the multi-modal embeddings in the worst case " f"({total_len} tokens in total, out of which " - f"{total_placeholders_by_modality} are reserved for " + f"{self._get_mm_num_tokens(mm_inputs)} are reserved for " "multi-modal embeddings). This may cause certain " "multi-modal inputs to fail during inference, even when " "the input text is short. To avoid this, you should " @@ -278,3 +248,12 @@ class MultiModalProfiler(Generic[_I]): multi_modal_data=mm_inputs["mm_kwargs"], multi_modal_placeholders=mm_inputs["mm_placeholders"], ) + + def get_mm_max_tokens( + self, + seq_len: int, + mm_counts: Optional[Mapping[str, int]] = None, + ) -> Mapping[str, int]: + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) + + return self._get_mm_num_tokens(mm_inputs) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 4f41fa08..eafa28d6 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -258,10 +258,16 @@ class MultiModalRegistry: """ if self.has_processor(model_config): processor = self.create_processor(model_config, disable_cache=True) + profiler = MultiModalProfiler(processor) + seq_len = model_config.max_model_len mm_limits = self.get_mm_limits_per_prompt(model_config) - return processor.info.get_mm_max_tokens_per_item( - seq_len, mm_limits) + + return profiler.get_mm_max_tokens( + seq_len, + {modality: 1 + for modality in mm_limits}, + ) return { key: plugin.get_max_multimodal_tokens(model_config)
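
The registry hunk above replaces the removed per-model `get_mm_max_tokens_per_item` override with a profiler-based lookup. Below is a minimal usage sketch (not part of the patch) of that new path, assuming the `MultiModalProfiler.get_mm_max_tokens` method introduced in `vllm/multimodal/profiling.py` by this change; the `processor` and `model_config` objects are assumed to come from the surrounding vLLM setup code and are placeholders here.

```python
# Illustrative sketch: derive the worst-case token count per multimodal item
# from dummy inputs, instead of calling the removed
# BaseProcessingInfo.get_mm_max_tokens_per_item().
from vllm.multimodal.profiling import MultiModalProfiler

# `processor` is a multimodal processor built for the target model and
# `model_config` is the corresponding vLLM model config (both assumed).
profiler = MultiModalProfiler(processor)

# One item per modality is enough to measure the per-item maximum,
# mirroring the updated MultiModalRegistry code in this patch.
mm_limits = processor.info.get_supported_mm_limits()
max_tokens_per_item = profiler.get_mm_max_tokens(
    seq_len=model_config.max_model_len,
    mm_counts={modality: 1 for modality in mm_limits},
)
```

The design choice reflected throughout the diff: rather than trusting each model to report its own maximum placeholder count, the profiler now runs the model's dummy inputs through HF processing and counts the resulting placeholders, so the two code paths can no longer disagree.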