[VLM] Remove BaseProcessingInfo.get_mm_max_tokens_per_item (#16408)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-04-11 00:06:58 +08:00 committed by GitHub
parent 7678fcd5b6
commit 83b824c8b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
39 changed files with 104 additions and 677 deletions

View File

@ -121,17 +121,21 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": 1}
```
### Maximum number of placeholder feature tokens
Also, override the abstract method {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item`
to return the maximum number of placeholder feature tokens per input item for each modality.
When calling the model, the output embeddings from the visual encoder are assigned to the input positions
containing placeholder feature tokens. Therefore, the number of placeholder feature tokens should be equal
to the size of the output embeddings.
:::::{tab-set}
::::{tab-item} Basic example: LLaVA
## 3. Specify dummy inputs
Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for
HF processing as well as memory profiling.
### For memory profiling
Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`
to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of
the model so that vLLM can reserve the correct amount of memory for it.
Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed to maximize the number of output embeddings, which is the same number as placeholder feature tokens.
::::{tab-set}
:::{tab-item} Basic example: LLaVA
:sync: llava
Looking at the code of HF's `LlavaForConditionalGeneration`:
@ -240,7 +244,7 @@ def get_num_image_tokens(
```
Notice that the number of image tokens doesn't depend on the image width and height.
So, we can calculate the maximum number of image tokens using any image size:
We can simply use a dummy `image_size`:
```python
def get_image_size_with_most_features(self) -> ImageSize:
@ -248,33 +252,35 @@ def get_image_size_with_most_features(self) -> ImageSize:
width = height = hf_config.image_size
return ImageSize(width=width, height=height)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
```
And thus, we can override the method as:
```python
def get_mm_max_tokens_per_item(
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
hf_config = self.get_hf_config()
target_width, target_height = self.info.get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
```
:::{note}
Our [actual code](gh-file:vllm/model_executor/models/llava.py) is more abstracted to support vision encoders other than CLIP.
:::
::::
::::{tab-item} Non-consecutive feature tokens: Fuyu
:::{tab-item} No input placeholders: Fuyu
:sync: fuyu
Looking at the code of HF's `FuyuForCausalLM`:
@ -394,188 +400,16 @@ num_patches_per_dim_w = image_width // patch_width
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
```
These image patches correspond to placeholder tokens (`|SPEAKER|`). So, we just need to maximize the number of image patches. Since input images are first resized
to fit within `image_processor.size`, we can maximize the number of image patches by inputting an image with size equal to `image_processor.size`.
We can calculate this in vLLM using this code:
```python
def get_num_image_patches(
self,
*,
image_width: int,
image_height: int,
) -> int:
image_processor = self.get_image_processor()
target_width = image_processor.size["width"]
target_height = image_processor.size["height"]
patch_width = image_processor.patch_size["width"]
patch_height = image_processor.patch_size["height"]
if not (image_width <= target_width and image_height <= target_height):
height_scale_factor = target_height / image_height
width_scale_factor = target_width / image_width
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
image_height = int(image_height * optimal_scale_factor)
image_width = int(image_width * optimal_scale_factor)
ncols = math.ceil(image_width / patch_width)
nrows = math.ceil(image_height / patch_height)
return ncols * nrows
```
These image patches correspond to placeholder tokens (`|SPEAKER|`). However, the processor also
inserts newline tokens (`|NEWLINE|`) as shown here:
```python
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L654-L670
tensor_of_image_ids = torch.full(
[num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
)
patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
assert num_patches == patches.shape[0]
if variable_sized:
# Now terminate each line with |NEWLINE|.
tensor_of_image_ids = tensor_of_image_ids.reshape(-1, image_width // patch_width)
newline_ids = torch.full(
[tensor_of_image_ids.shape[0], 1],
image_newline_id,
dtype=torch.int32,
device=image_input.device,
)
tensor_of_image_ids = torch.cat([tensor_of_image_ids, newline_ids], dim=1)
tensor_of_image_ids = tensor_of_image_ids.reshape(-1)
```
So, the layout of tokens for an image is:
```
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
...
|SPEAKER||SPEAKER|...|SPEAKER||NEWLINE|
```
This makes the placeholder tokens non-consecutive in the prompt.
Since vLLM requires the feature tokens to be consecutive, **we also treat the newline tokens as feature tokens**.
So overall, the total number of feature tokens is
```python
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
image_processor = self.get_image_processor()
target_width = image_processor.size["width"]
target_height = image_processor.size["height"]
patch_width = image_processor.patch_size["width"]
patch_height = image_processor.patch_size["height"]
if not (image_width <= target_width and image_height <= target_height):
height_scale_factor = target_height / image_height
width_scale_factor = target_width / image_width
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
image_height = int(image_height * optimal_scale_factor)
image_width = int(image_width * optimal_scale_factor)
ncols = math.ceil(image_width / patch_width)
nrows = math.ceil(image_height / patch_height)
return (ncols + 1) * nrows
```
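To make the formula concrete: assuming the default `adept/fuyu-8b` image processor (a 1920x1080 target size with 30x30 patches), the largest possible image gives `ncols = 64` and `nrows = 36`, i.e. `(64 + 1) * 36 = 2340` feature tokens once the newline tokens are included.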
To calculate the maximum number of image tokens, recall that input images are first resized
to fit within `image_processor.size`. The maximum possible dimensions of the image before
being converted into patches are therefore equal to `image_processor.size`.
```python
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
return ImageSize(width=image_processor.size["width"],
height=image_processor.size["height"])
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
```
And thus, we can override the method as:
```python
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
```
:::{note}
Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) returns `ncols` and `nrows` directly instead of the total token count.
This is because `ncols` and `nrows` are used to specify the layout of the feature tokens (as shown in Step 4 of this guide).
:::
::::
:::::
## 3. Specify dummy inputs
Then, inherit {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` to construct dummy inputs for
HF processing as well as memory profiling.
### For memory profiling
Override the abstract method {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`
to construct dummy inputs for memory profiling. This dummy input should result in the worst-case memory usage of
the model so that vLLM can reserve the correct amount of memory for it.
Assuming that the memory usage increases with the number of tokens, the dummy input can be constructed based
on the code for {meth}`~vllm.multimodal.processing.BaseProcessingInfo.get_mm_max_tokens_per_item`.
::::{tab-set}
:::{tab-item} Basic example: LLaVA
:sync: llava
Making use of the `get_image_size_with_most_features` method implemented in Step 2:
```python
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
image_token = processor.image_token
hf_config = self.get_hf_config()
target_width, target_height = self.info.get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=mm_data,
)
```
:::
:::{tab-item} No input placeholders: Fuyu
:sync: fuyu
Fuyu does not expect image placeholders in the inputs to HF processor, so
the dummy prompt text is empty regardless of the number of images.
Otherwise, the logic of this method is very similar to LLaVA:
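As a rough, assumed sketch based on the LLaVA example above (the exact implementation lives in gh-file:vllm/model_executor/models/fuyu.py), the Fuyu builder might look like:

```python
# Assumed sketch: Fuyu's HF processor inserts the image tokens itself,
# so the dummy prompt text stays empty and only the worst-case image matters.
def get_dummy_processor_inputs(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> ProcessorInputs:
    target_width, target_height = \
        self.info.get_image_size_with_most_features()
    num_images = mm_counts.get("image", 0)

    mm_data = {
        "image":
        self._get_dummy_images(width=target_width,
                               height=target_height,
                               num_images=num_images)
    }

    return ProcessorInputs(
        prompt_text="",
        mm_data=mm_data,
    )
```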

View File

@ -76,11 +76,6 @@ def test_processor_override(
if v == config.boi_token_index]
# patch sizes and masks
patch_token_id = vocab[hf_processor.img_patch_token]
num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id)
mm_counts = {"image": num_imgs}
assert num_patches / num_imgs <= \
processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"]
num_patches_per_chunk = processor.info.get_patch_per_chunk(
config.vision_config)
assert prompt_token_ids.count(config.image_token_index) \

View File

@ -408,13 +408,6 @@ class AriaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())

View File

@ -117,31 +117,6 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
def get_image_processor(self) -> GotOcr2ImageProcessor:
return self.get_hf_processor().image_processor
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor()
image_processor = hf_processor.image_processor
image_size = self.get_image_size_with_most_features()
num_patches = self.get_num_patches(
image_width=image_size.width,
image_height=image_size.height,
size=image_processor.size,
min_patches=image_processor.min_patches,
max_patches=image_processor.max_patches,
)
img_patches_per_tile = (hf_processor.img_size //
hf_processor.patch_size)**2
return num_patches * img_patches_per_tile
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}

View File

@ -406,13 +406,6 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
return hf_config.num_query_tokens

View File

@ -64,13 +64,6 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
processor = self.get_hf_processor()
return processor.image_seq_length

View File

@ -30,9 +30,6 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
) -> int:
return self.get_patch_grid_length()**2 + 1
def get_max_image_tokens(self) -> int:
return self.get_patch_grid_length()**2 + 1
def get_image_size(self) -> int:
return self.vision_config.image_size

View File

@ -168,20 +168,6 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
image_width=x[1], image_height=x[0]))
return ImageSize(width=width, height=height)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
num_images = mm_counts.get("image", 0)
max_image_size = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_height=max_image_size.height,
image_width=max_image_size.width,
cropping=num_images <= 2)
return {"image": max_image_tokens}
class DeepseekVL2DummyInputsBuilder(
BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]):

View File

@ -764,17 +764,10 @@ class Florence2ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_max_image_tokens(self) -> int:
def get_num_image_tokens(self) -> int:
processor_config = self.ctx.get_hf_image_processor_config()
return processor_config["image_seq_length"]
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
class Florence2DummyInputsBuilder(
BaseDummyInputsBuilder[Florence2ProcessingInfo]):
@ -871,7 +864,7 @@ class Florence2MultiModalProcessor(
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
pad_token_id = hf_config.pad_token_id
num_image_tokens = self.info.get_max_image_tokens()
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [pad_token_id] * num_image_tokens
return [

View File

@ -80,13 +80,6 @@ class FuyuProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_image_feature_grid_size(
self,
*,
@ -129,14 +122,6 @@ class FuyuProcessingInfo(BaseProcessingInfo):
return ImageSize(width=image_processor.size["width"],
height=image_processor.size["height"])
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):

View File

@ -68,13 +68,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _resolve_image_kwargs(
self,
processor: Gemma3Processor,
@ -228,15 +221,6 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
# Result in the max possible feature size (h:w = max_num_crops:1)
return ImageSize(height=50 * max_num_crops, width=50)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):

View File

@ -431,13 +431,6 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_feature_tokens()}
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config

View File

@ -412,19 +412,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
**kwargs,
)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
max_tokens_one_image = self.get_max_image_tokens(use_msac=None)
if mm_counts.get("image", 0) <= 1:
max_tokens_per_image = max_tokens_one_image
else:
max_tokens_per_image = self.get_max_image_tokens(use_msac=False)
return {"image": max_tokens_per_image}
def get_num_image_tokens(
self,
*,
@ -442,16 +429,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
use_msac=use_msac,
)
def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
use_msac=use_msac,
)
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
):

View File

@ -97,13 +97,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _resize_output_size(self,
*,
height: int,
@ -287,15 +280,6 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
height=image_processor.size["longest_edge"],
)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
):

View File

@ -458,13 +458,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
@ -480,15 +473,6 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
image_height=image_height,
)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()

View File

@ -137,13 +137,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _apply_feature_select_strategy(
self,
strategy: str,

View File

@ -61,22 +61,6 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
)
return {"video": max_video_tokens}
def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()

View File

@ -101,16 +101,6 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len, mm_counts),
}
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# with additional logic afterwards taken from LlavaOnevisionProcessor
def _get_num_unpadded_features(

View File

@ -142,17 +142,6 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {**super().get_supported_mm_limits(), "audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
**super().get_mm_max_tokens_per_item(seq_len, mm_counts),
"audio":
self.get_max_audio_tokens(),
}
def get_audio_placeholder(
self,
audio_lens: int,

View File

@ -346,18 +346,6 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return mm_limits
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(
seq_len, mm_counts)
return mm_max_tokens
def get_slice_image_placeholder(
self,
image_size: ImageSize,

View File

@ -162,13 +162,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
@ -186,14 +179,6 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)

View File

@ -106,16 +106,6 @@ class MllamaProcessingInfo(BaseProcessingInfo):
image_size = self.get_hf_config().vision_config.image_size
return calc_token_per_chunk(image_size)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
token_per_chunk = self.get_token_per_chunk_from_config()
mm_max_tokens = vision_config.max_num_tiles * token_per_chunk
return {"image": mm_max_tokens}
def get_num_tiles_per_image(self, image_height: int,
image_width: int) -> int:
vision_config = self.get_hf_config().vision_config

View File

@ -498,17 +498,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
image_processor = self.get_hf_processor().image_processor
return image_processor.max_patches
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
vision_config = self.get_hf_config().vision_config
patch_per_chunk = self.get_patch_per_chunk(vision_config)
num_patches = self.get_max_num_tiles() + 1
return {"image": patch_per_chunk * num_patches}
def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config
image_size = vision_config.image_size
@ -516,14 +505,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
return ImageSize(height=self.get_max_num_tiles() * image_size,
width=image_size)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
):

View File

@ -1164,13 +1164,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
@ -1195,15 +1188,6 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return extra + joint
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()

View File

@ -13,7 +13,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets,
PromptInsertion, PromptUpdate,
@ -72,16 +73,18 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(
def get_num_image_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
*,
image_width: int,
image_height: int,
) -> int:
vision_encoder_info = self.get_vision_encoder_info()
return vision_encoder_info.get_max_image_tokens()
return vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
)
class PaliGemmaDummyInputsBuilder(
@ -148,12 +151,30 @@ class PaliGemmaMultiModalProcessor(
image_token_id = hf_config.image_token_index
tokenizer = self.info.get_tokenizer()
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
def get_insertion(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
image_tokens = [image_token_id] * num_image_tokens
return PromptUpdateDetails.select_token_id(
image_tokens + [bos_token_id],
embed_token_id=image_token_id,
)
# Paligemma 1 and 2 have different tokenizer.add_bos_token
# Insert <image>*n + <bos> after <bos> for Paligemma 1
# Insert <image>*n + <bos> for Paligemma 2
@ -162,10 +183,7 @@ class PaliGemmaMultiModalProcessor(
modality="image",
target=PromptIndexTargets.prefix(
[bos_token_id] if tokenizer.add_bos_token else []),
insertion=PromptUpdateDetails.select_token_id(
image_tokens + [bos_token_id],
embed_token_id=image_token_id,
),
insertion=get_insertion,
)
]

View File

@ -321,21 +321,6 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
return {"image": max_image_tokens}
def get_num_image_tokens(
self,
*,

View File

@ -167,13 +167,6 @@ class PixtralProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_vision_config(
self,
processor: Optional[PixtralProcessorAdapter] = None,
@ -207,14 +200,6 @@ class PixtralProcessingInfo(BaseProcessingInfo):
return ImageSize(width=max_image_size, height=max_image_size)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
@ -938,14 +923,6 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
)
return ncols * nrows
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size()
return self.get_num_image_tokens(
image_width=image_size,
image_height=image_size,
)
def get_image_size(self) -> int:
return self.vision_config.image_size

View File

@ -45,9 +45,6 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": 0}
class PrithviGeoSpatialMAEInputBuilder(
BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):

View File

@ -109,17 +109,6 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_config = self.get_hf_config()
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
return {"audio": max_output_lengths}
class Qwen2AudioDummyInputsBuilder(
BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):

View File

@ -818,16 +818,6 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len, mm_counts),
}
def _get_vision_info(
self,
*,

View File

@ -530,13 +530,6 @@ class QwenVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
vision_config = hf_config.visual

View File

@ -33,9 +33,6 @@ class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
) -> int:
return self.get_patch_grid_length()**2
def get_max_image_tokens(self) -> int:
return self.get_patch_grid_length()**2
def get_image_size(self) -> int:
return self.vision_config.image_size

View File

@ -459,13 +459,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
@ -481,15 +474,6 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
image_height=image_height,
)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()

View File

@ -2,7 +2,6 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
@ -107,17 +106,6 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
feature_extractor = self.get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)
return {"audio": max_audio_tokens * _MAX_ENCODER_BATCH_SIZE}
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
):

View File

@ -33,10 +33,6 @@ class VisionEncoderInfo(ABC, Generic[_C]):
) -> int:
raise NotImplementedError
@abstractmethod
def get_max_image_tokens(self) -> int:
raise NotImplementedError
@abstractmethod
def get_image_size(self) -> int:
raise NotImplementedError

View File

@ -538,16 +538,9 @@ class WhisperProcessingInfo(BaseProcessingInfo):
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_max_audio_tokens(self) -> int:
def get_num_audio_tokens(self) -> int:
return self.get_hf_config().max_source_positions
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"audio": self.get_max_audio_tokens()}
class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
@ -630,7 +623,7 @@ class WhisperMultiModalProcessor(
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
num_tokens = self.info.get_max_audio_tokens()
num_tokens = self.info.get_num_audio_tokens()
return [
PromptReplacement(
modality="audio",

View File

@ -1034,21 +1034,6 @@ class BaseProcessingInfo:
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
_I = TypeVar("_I", bound=BaseProcessingInfo)

View File

@ -68,7 +68,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
) -> ProcessorInputs:
"""
Build the input which, after processing, results in
:code:`self.info.get_mm_max_tokens_per_item()` placeholder tokens.
the maximum possible number of placeholder tokens.
"""
raise NotImplementedError
@ -152,8 +152,11 @@ class MultiModalProfiler(Generic[_I]):
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_counts: Optional[Mapping[str, int]] = None,
) -> MultiModalInputs:
if mm_counts is None:
mm_counts = self.get_mm_limits()
factory = self.dummy_inputs
processor_inputs = factory.get_dummy_processor_inputs(
seq_len, mm_counts)
@ -164,53 +167,23 @@ class MultiModalProfiler(Generic[_I]):
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_and_validate_mm_inputs(
def _get_mm_num_tokens(
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
mm_inputs: MultiModalInputs,
) -> Mapping[str, int]:
) -> tuple[MultiModalInputs, Mapping[str, int]]:
if mm_counts is None:
mm_counts = self.get_mm_limits()
info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
seq_len, mm_counts)
if mm_counts.keys() - mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits` "
f"({set(mm_counts.keys())}) should be a subset of those "
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = {
return {
modality: sum(item.get_num_embeds() for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
raise AssertionError(
f"The processed dummy data has a total of "
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
return mm_inputs, total_placeholders_by_modality
def get_encoder_dummy_data(
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyEncoderData:
(
mm_inputs,
total_placeholders_by_modality,
) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
# For encoder-decoder models, use encoder prompt token ids instead of
@ -232,7 +205,7 @@ class MultiModalProfiler(Generic[_I]):
" is too short "
"to hold the multi-modal embeddings in the worst case "
f"({total_len} tokens in total, out of which "
f"{total_placeholders_by_modality} are reserved for "
f"{self._get_mm_num_tokens(mm_inputs)} are reserved for "
"multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should "
@ -246,10 +219,7 @@ class MultiModalProfiler(Generic[_I]):
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> DummyDecoderData:
(
mm_inputs,
total_placeholders_by_modality,
) = self.get_and_validate_mm_inputs(seq_len, mm_counts)
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids)
@ -263,7 +233,7 @@ class MultiModalProfiler(Generic[_I]):
"is too short "
"to hold the multi-modal embeddings in the worst case "
f"({total_len} tokens in total, out of which "
f"{total_placeholders_by_modality} are reserved for "
f"{self._get_mm_num_tokens(mm_inputs)} are reserved for "
"multi-modal embeddings). This may cause certain "
"multi-modal inputs to fail during inference, even when "
"the input text is short. To avoid this, you should "
@ -278,3 +248,12 @@ class MultiModalProfiler(Generic[_I]):
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=mm_inputs["mm_placeholders"],
)
def get_mm_max_tokens(
self,
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
) -> Mapping[str, int]:
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs)
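In effect, the worst-case token count per modality is now measured by running the profiler's dummy inputs through the processor, rather than being declared by each model. A rough usage sketch (hypothetical; `processor`, `model_config`, and `mm_limits` are assumed to be set up as in the registry change below):

```python
# Hypothetical sketch: derive worst-case placeholder counts per modality
# from processed dummy inputs rather than a hand-written maximum.
profiler = MultiModalProfiler(processor)
mm_max_tokens = profiler.get_mm_max_tokens(
    seq_len=model_config.max_model_len,
    mm_counts={modality: 1 for modality in mm_limits},
)
```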

View File

@ -258,10 +258,16 @@ class MultiModalRegistry:
"""
if self.has_processor(model_config):
processor = self.create_processor(model_config, disable_cache=True)
profiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config)
return processor.info.get_mm_max_tokens_per_item(
seq_len, mm_limits)
return profiler.get_mm_max_tokens(
seq_len,
{modality: 1
for modality in mm_limits},
)
return {
key: plugin.get_max_multimodal_tokens(model_config)