[VLM] Move supported limits and max tokens to merged multi-modal processor (#11669)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Cyrus Leung 2025-01-01 23:44:42 +08:00 committed by GitHub
parent 73001445fb
commit a115ac46b5
16 changed files with 351 additions and 361 deletions

View File

@ -4,7 +4,7 @@ from typing import Optional
import pytest
from transformers import AutoTokenizer
from vllm.inputs import InputContext, InputProcessingContext
from vllm.inputs import InputProcessingContext
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from .....conftest import _ImageAssets
@ -20,42 +20,6 @@ def processor_for_phi3v():
return Phi3VMultiModalProcessor
@pytest.fixture()
def get_max_phi3v_image_tokens():
from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
return get_max_phi3v_image_tokens
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,expected_max_tokens", [
(4, 781),
(16, 2653),
])
def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
num_crops: int, expected_max_tokens: int):
"""Ensure get_max_phi3v_image_tokens handles num_crops properly."""
# NOTE: mm_processor_kwargs on the context in this test is unused, since
# this is testing the mapper directly. In practice, the processor kwargs
# are wrapped in a closure when calling the max tokens func. We explicitly
# do NOT use the mm_processor_kwargs in the model context here to ensure
# that the max image tokens implementation is referencing a mix of the
# kwargs to the function and the original mm_processor_kwargs in case
# values are somehow updated and end up in a bad state.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
actual_max_tokens = get_max_phi3v_image_tokens(
InputContext(ctx.model_config),
num_crops=num_crops,
)
assert expected_max_tokens == actual_max_tokens
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"num_crops,expected_toks_per_img",
@ -77,6 +41,7 @@ def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)

View File

@ -3,7 +3,7 @@ from typing import Any, Dict, Tuple
import pytest
from transformers import AutoTokenizer
from vllm.inputs import InputContext, InputProcessingContext
from vllm.inputs import InputProcessingContext
from .....conftest import _ImageAssets
from ....utils import build_model_context
@ -22,39 +22,6 @@ def processor_for_qwen2_vl():
return Qwen2VLMultiModalProcessor
@pytest.fixture()
def get_max_qwen2_vl_image_tokens():
from vllm.model_executor.models.qwen2_vl import (
get_max_qwen2_vl_image_tokens)
return get_max_qwen2_vl_image_tokens
@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
({}, 16384),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 324),
])
@pytest.mark.parametrize("model", [MODEL])
def test_qwen2_vl_max_image_tokens(
get_max_qwen2_vl_image_tokens,
model: str,
mm_processor_kwargs: Dict[str, Any],
expected_max_tokens: int,
):
"""Ensure that the max token calc handles min/max pixels properly."""
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
mm_processor_kwargs=None,
)
actual_max_tokens = get_max_qwen2_vl_image_tokens(
InputContext(ctx.model_config), **mm_processor_kwargs)
assert actual_max_tokens == expected_max_tokens
@pytest.mark.parametrize(
"mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
({}, 1426, (5704, 1176)),
@ -82,6 +49,7 @@ def test_processor_override(
model_name=model,
tokenizer_name=model,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)

View File

@ -538,6 +538,11 @@ def _test_processing_cache_correctness(
else:
hf_overrides = {}
limit_mm_per_prompt = {
modality: 3 if supports_multi else 1
for modality, supports_multi in modalities.items()
}
model_config = ModelConfig(
model_id,
task="auto",
@ -548,6 +553,7 @@ def _test_processing_cache_correctness(
dtype="float16",
revision=None,
hf_overrides=hf_overrides,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
@ -580,18 +586,14 @@ def _test_processing_cache_correctness(
min_wh=128,
max_wh=256),
"audio":
partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
}
input_max_count = {
modality: 3 if supports_multi else 1
for modality, supports_multi in modalities.items()
partial(_rand_audio, rng, min_len=512, max_len=1024, sr=16000),
}
for batch_idx in range(num_batches):
mm_data = {
k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
for _ in range(rng.randint(input_max_count[k]))]
for _ in range(rng.randint(limit_mm_per_prompt[k]))]
for k in modalities
}

View File

@ -331,13 +331,7 @@ class InputRegistry:
trust_remote_code=model_config.trust_remote_code,
)
processor = mm_registry.create_processor(model_config, tokenizer)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_max_tokens = mm_registry.get_max_tokens_by_modality(
model_config)
dummy_data = processor.get_dummy_data(seq_len, mm_counts,
mm_max_tokens)
dummy_data = processor.get_dummy_data(seq_len)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:

View File

@ -1,5 +1,5 @@
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
@ -9,7 +9,6 @@ from transformers import BatchFeature, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -87,8 +86,8 @@ class AriaVisionModel(nn.Module):
def forward(
self,
pixel_values: torch.Tensor,
pixel_mask: Optional[torch.BoolTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]:
pixel_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
vit_oup = self.vision_model(
@ -100,7 +99,8 @@ class AriaVisionModel(nn.Module):
return vit_oup, image_atts
def _create_patch_attention_mask(self, pixel_mask):
def _create_patch_attention_mask(
self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
if pixel_mask is None:
return None
@ -115,7 +115,8 @@ class AriaVisionModel(nn.Module):
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
def _create_image_attention_mask(self, patch_attention_mask):
def _create_image_attention_mask(
self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
if patch_attention_mask is None:
return None
@ -125,13 +126,13 @@ class AriaVisionModel(nn.Module):
class FFN(nn.Module):
def __init__(self, embed_dim, ff_dim, output_dim):
def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
super().__init__()
self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
self.act = get_act_fn("gelu_new")
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.linear_in(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.linear_out(hidden_states)
@ -140,7 +141,7 @@ class FFN(nn.Module):
class CrossAttention(nn.Module):
def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
super().__init__()
self.num_heads = num_heads
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
@ -149,12 +150,16 @@ class CrossAttention(nn.Module):
self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.linear = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(drop_out_rate)
self.layer_norm = nn.LayerNorm(embed_dim)
self.ln_kv = nn.LayerNorm(kv_dim)
def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
def forward(
self,
x: torch.Tensor,
hidden_states: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
normed_hidden_states = self.layer_norm(hidden_states)
query = self.q_proj(normed_hidden_states).permute(1, 0, 2)
@ -169,11 +174,7 @@ class CrossAttention(nn.Module):
attn_output = attn_output.permute(1, 0, 2)
if add_residual:
attn_output = hidden_states + self.dropout(
self.linear(attn_output))
else:
attn_output = self.dropout(self.linear(attn_output))
attn_output = self.linear(attn_output)
return attn_output
@ -201,14 +202,14 @@ class AriaProjector(nn.Module):
def __init__(
self,
patch_to_query_dict,
embed_dim,
num_heads,
kv_dim,
ff_dim,
output_dim,
norm_layer=nn.LayerNorm,
):
patch_to_query_dict: dict[int, int],
embed_dim: int,
num_heads: int,
kv_dim: int,
ff_dim: int,
output_dim: int,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
) -> None:
super().__init__()
self.patch_to_query_dict = patch_to_query_dict
self.embed_dim = embed_dim
@ -224,7 +225,11 @@ class AriaProjector(nn.Module):
self.ln_ffn = norm_layer(embed_dim)
self.ffn = FFN(embed_dim, ff_dim, output_dim)
def forward(self, x, attn_mask=None):
def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
bs = x.shape[0]
queries = self.query.unsqueeze(0).repeat(bs, 1, 1)
@ -442,13 +447,18 @@ def build_mm_projector(config: PretrainedConfig):
)
def get_max_aria_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())
class AriaMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
@ -468,13 +478,13 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
hf_config = self.ctx.get_hf_config()
image_token_id = hf_config.image_token_index
max_image_tokens = get_max_aria_image_tokens(self.ctx)
num_image_tokens = self._get_num_image_tokens()
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=[image_token_id] * max_image_tokens,
replacement=[image_token_id] * num_image_tokens,
)
]
@ -504,7 +514,6 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
"""

View File

@ -9,7 +9,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2Processor,
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@ -18,7 +17,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@ -398,15 +396,17 @@ class Blip2QFormerModel(nn.Module):
return sequence_output
def get_max_blip2_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(Blip2Config)
return hf_config.num_query_tokens
class Blip2MultiModalProcessor(BaseMultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def _get_num_image_tokens(self) -> int:
hf_config = self.ctx.get_hf_config(Blip2Config)
return hf_config.num_query_tokens
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
def _get_hf_processor(self) -> Blip2Processor:
return self.ctx.get_hf_processor(Blip2Processor)
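For BLIP-2 the per-image token count is simply the number of Q-Former query tokens, so the processor now reports a hard limit of one image plus that fixed budget. A tiny sketch, assuming the 32 query tokens used by the released BLIP-2 configs:

num_query_tokens = 32  # assumption: typical value in released BLIP-2 configs

supported_limits = {"image": 1}                    # get_supported_mm_limits()
max_tokens_per_item = {"image": num_query_tokens}  # get_mm_max_tokens_per_item()
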
@ -427,7 +427,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
max_image_tokens = get_max_blip2_image_tokens(self.ctx)
max_image_tokens = self._get_num_image_tokens()
return [
PromptReplacement(
@ -480,7 +480,6 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

View File

@ -11,7 +11,6 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@ -31,7 +30,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@ -43,11 +41,6 @@ from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_IMAGE_SEQ_LENGTH = 1024
class ChameleonImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
@ -55,14 +48,17 @@ class ChameleonImagePixelInputs(TypedDict):
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
def get_max_chameleon_image_tokens(ctx: InputContext):
return CHAMELEON_IMAGE_SEQ_LENGTH
class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def _get_num_image_tokens(self) -> int:
processor = self._get_hf_processor()
return processor.image_seq_length
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
def _get_hf_processor(self) -> ChameleonProcessor:
return self.ctx.get_hf_processor(ChameleonProcessor)
@ -88,7 +84,7 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH,
processor.image_token * self._get_num_image_tokens(),
processor.image_end_token,
]),
)
@ -98,12 +94,15 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
config = self.ctx.get_hf_config(ChameleonConfig)
width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH,
height=CHAMELEON_CROP_SIZE_HEIGHT,
self._get_dummy_images(width=width,
height=height,
num_images=num_images)
}
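The hard-coded crop size and image sequence length are gone; both are now read from the HF config and processor at runtime. A hedged sketch of the lookups (the checkpoint name is only an example, and the expected values are the ones the removed constants used to hard-code):

from transformers import ChameleonConfig, ChameleonProcessor

# Assumption: the public (gated) Chameleon checkpoint; adjust to your model.
model_id = "facebook/chameleon-7b"
config = ChameleonConfig.from_pretrained(model_id)
processor = ChameleonProcessor.from_pretrained(model_id)

width = height = config.vq_config.resolution   # expected 512 (was CHAMELEON_CROP_SIZE_*)
num_image_tokens = processor.image_seq_length  # expected 1024 (was CHAMELEON_IMAGE_SEQ_LENGTH)
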
@ -902,7 +901,6 @@ class ChameleonModel(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
@ -931,9 +929,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
self.model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
CHAMELEON_CROP_SIZE_WIDTH)
vq_config: ChameleonVQVAEConfig = self.config.vq_config
expected_dims = (3, vq_config.resolution, vq_config.resolution)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:

View File

@ -25,7 +25,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
@ -34,7 +33,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@ -48,9 +47,6 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
class FuyuImagePatchInputs(TypedDict):
type: Literal["image_patches"]
@ -67,43 +63,49 @@ class FuyuImagePatchInputs(TypedDict):
"""
def _get_fuyu_num_image_tokens(
image_height: int,
image_width: int,
) -> Tuple[int, int]:
"""
Calculate the number of image tokens needed for a given image size.
The expected Fuyu image prompts can be expressed as:
.. code-block::
(image_token * ncols + newline_token) * nrows
Args:
image_size: Tuple[int, int] - `(width, height)` of the image
Returns:
ncols: int - number of image tokens in `x` direction
nrows: int - number of image tokens in `y` direction
"""
ncols = math.ceil(image_width / 30)
nrows = math.ceil(image_height / 30)
return ncols, nrows
def get_max_fuyu_image_tokens(ctx: InputContext):
ncols, nrows = _get_fuyu_num_image_tokens(
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
return (ncols + 1) * nrows
class FuyuMultiModalProcessor(BaseMultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def _get_image_target_size(self) -> ImageSize:
processor = self._get_hf_processor()
image_processor: FuyuImageProcessor = processor.image_processor
target_size = image_processor.size
return ImageSize(width=target_size["width"],
height=target_size["height"])
def _get_image_grid_size(
self,
*,
image_width: int,
image_height: int,
) -> tuple[int, int]:
target_width, target_height = self._get_image_target_size()
if not (image_width <= target_width and image_height <= target_height):
height_scale_factor = target_height / image_height
width_scale_factor = target_width / image_width
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
image_height = int(image_height * optimal_scale_factor)
image_width = int(image_width * optimal_scale_factor)
ncols = math.ceil(image_width / 30)
nrows = math.ceil(image_height / 30)
return ncols, nrows
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
target_width, target_height = self._get_image_target_size()
max_ncols, max_nrows = self._get_image_grid_size(
image_width=target_width,
image_height=target_height,
)
max_image_tokens = (max_ncols + 1) * max_nrows
return {"image": max_image_tokens}
def _get_hf_processor(self) -> FuyuProcessor:
return self.ctx.get_hf_processor(FuyuProcessor)
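For reference, the new per-item maximum works out the same as the removed get_max_fuyu_image_tokens. A worked sketch of the arithmetic, assuming the image processor's target size matches the removed MAX_IMAGE_FEATURE_SIZE constants (1920x1080) and the 30-pixel patch stride used above:

import math

target_width, target_height = 1920, 1080  # assumption: default FuyuImageProcessor size

# Same grid computation as _get_image_grid_size() when the image already
# fits within the target size.
ncols = math.ceil(target_width / 30)   # 64
nrows = math.ceil(target_height / 30)  # 36

# One newline token follows each row of image tokens.
max_image_tokens = (ncols + 1) * nrows  # 2340
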
@ -166,28 +168,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
eot_token_id = tokenizer.bos_token_id
assert isinstance(eot_token_id, int)
hf_processor = self._get_hf_processor()
image_processor: FuyuImageProcessor = hf_processor.image_processor
target_size = image_processor.size
target_height, target_width = (target_size["height"],
target_size["width"])
def get_replacement_fuyu(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
width, height = image_size.width, image_size.height
if not (width <= target_width and height <= target_height):
height_scale_factor = target_height / height
width_scale_factor = target_width / width
optimal_scale_factor = min(height_scale_factor,
width_scale_factor)
height = int(height * optimal_scale_factor)
width = int(width * optimal_scale_factor)
ncols, nrows = _get_fuyu_num_image_tokens(
image_width=width,
image_height=height,
ncols, nrows = self._get_image_grid_size(
image_width=image_size.width,
image_height=image_size.height,
)
return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
@ -225,12 +212,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
target_width, target_height = self._get_image_target_size()
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
@ -240,7 +228,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

View File

@ -119,6 +119,12 @@ def get_max_llava_image_tokens(ctx: InputContext):
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": get_max_llava_image_tokens(self.ctx)}
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
@ -324,7 +330,6 @@ def init_vision_tower_for_llava(
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
@ -649,7 +654,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass

View File

@ -23,7 +23,6 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@ -306,25 +305,32 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
def get_max_phi3v_image_tokens(
ctx: InputContext,
*,
num_crops: Optional[int] = None,
) -> int:
hf_processor_mm_kwargs = {}
if num_crops:
hf_processor_mm_kwargs["num_crops"] = num_crops
processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)
return processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def _get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
processor = self._get_hf_processor()
return processor.calc_num_image_tokens_from_image_size( # type: ignore
width=image_width,
height=image_height,
)
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
max_image_tokens = self._get_num_image_tokens(
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return {"image": max_image_tokens}
def _get_hf_processor(
self,
*,
@ -332,6 +338,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
) -> ProcessorMixin:
if num_crops is not None:
return self.ctx.get_hf_processor(num_crops=num_crops)
return self.ctx.get_hf_processor()
def _call_hf_processor(
@ -375,7 +382,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
image_processor = hf_processor.image_processor # type: ignore
tokenizer = self._get_tokenizer()
bos_token_id = tokenizer.bos_token_id
@ -385,9 +391,9 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width,
height=image_size.height,
num_tokens = self._get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
@ -467,7 +473,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
return result
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(

View File

@ -33,13 +33,12 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
@ -80,15 +79,18 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
return feat_lengths, output_lengths
def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
hf_config = ctx.get_hf_config(Qwen2AudioConfig)
max_source_position = hf_config.audio_config.max_source_positions
output_lengths = (max_source_position - 2) // 2 + 1
return output_lengths
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
return {"audio": max_output_lengths}
def _get_hf_processor(
self,
*,
@ -157,11 +159,21 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
audio_output_lengths = []
else:
assert isinstance(feature_attention_mask, torch.Tensor)
_, audio_output_lengths = _get_feat_extract_output_lengths(
_, audio_output_lens = _get_feat_extract_output_lengths(
feature_attention_mask.sum(-1))
audio_output_lengths = audio_output_lens.tolist()
def get_replacement_qwen2_audio(item_idx: int):
return [placeholder] * audio_output_lengths[item_idx]
num_placeholders = audio_output_lengths[item_idx]
if num_placeholders == 0:
audios = mm_items.get_items("audio", AudioProcessorItems)
audio = audios.get(item_idx)
raise ValueError(
f"The audio {audio} (len={len(audio)}) is too short "
"to be represented inside the model")
return [placeholder] * num_placeholders
return [
PromptReplacement(
@ -171,6 +183,14 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
)
]
def _always_apply_prompt_replacements(self) -> bool:
# HF never applies prompt replacements, so we have to do it ourselves
# _find_placeholders may incorrectly think that HF has already performed
# processing for multi-audio input when the input audios are short
# (the corresponding placeholders may take up fewer tokens than
# the number of audio items)
return True
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
@ -192,8 +212,6 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_qwen2_audio_audio_tokens)
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

View File

@ -40,7 +40,6 @@ from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
@ -650,8 +649,9 @@ def _get_vision_info(
width: int,
min_pixels: int,
max_pixels: int,
*,
do_resize: bool = True,
data_type_key: str = "image",
modality: str = "image",
mm_count: int = 1,
):
"""Get information (resized height / width and number of vision tokens)
@ -671,11 +671,12 @@ def _get_vision_info(
else:
resized_height, resized_width = height, width
if data_type_key == "image":
if modality == "image":
grid_t = mm_count
else:
assert data_type_key == "video"
elif modality == "video":
grid_t = max(mm_count // temporal_patch_size, 1)
else:
raise ValueError(f"Modality {modality} is not supported")
grid_h = resized_height // patch_size
grid_w = resized_width // patch_size
@ -691,41 +692,11 @@ def _get_image_processor(hf_processor: Qwen2VLProcessor):
return image_processor
def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
data_type_key: str,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None) -> int:
hf_config = ctx.get_hf_config(Qwen2VLConfig)
vision_config = hf_config.vision_config
hf_processor = ctx.get_hf_processor(Qwen2VLProcessor)
image_processor = _get_image_processor(hf_processor)
_, _, max_llm_image_tokens = _get_vision_info(
vision_config,
height=9999999,
width=9999999,
min_pixels=min_pixels or image_processor.min_pixels,
max_pixels=max_pixels or image_processor.max_pixels,
data_type_key=data_type_key,
)
return max_llm_image_tokens
get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens,
data_type_key="image")
get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
data_type_key="video")
class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
def __init__(self, data: dict, modality: str) -> None:
super().__init__(data)
self.modality = modality
super().__init__(data, modality)
grid_thw = data[f"{modality}_grid_thw"]
slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
@ -734,9 +705,6 @@ class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
for i in range(len(grid_thw))
]
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r})")
def get_count(self) -> int:
return len(self.data[f"{self.modality}_grid_thw"])
@ -792,6 +760,32 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def _get_max_mm_tokens(self, modality: str) -> int:
hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
vision_config = hf_config.vision_config
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
_, _, max_llm_image_tokens = _get_vision_info(
vision_config,
height=9999999,
width=9999999,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
modality=modality,
)
return max_llm_image_tokens
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {
"image": self._get_max_mm_tokens("image"),
"video": self._get_max_mm_tokens("video"),
}
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser()
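Since the per-prompt limits and max-token calculation now live on the processor, they are exercised through the normal engine arguments rather than the removed module-level helpers. A hedged usage sketch (the model name and pixel bounds mirror the removed test; the values are illustrative and running this downloads the checkpoint):

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-VL-2B-Instruct",  # assumption: the model used by the test above
    limit_mm_per_prompt={"image": 2},   # now validated by the merged processor
    mm_processor_kwargs={
        "min_pixels": 64**2,
        "max_pixels": 512**2,  # caps per-image tokens (324 in the removed test)
    },
)
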
@ -908,9 +902,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_qwen2_vl_video_tokens)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):

View File

@ -2,7 +2,7 @@
"""PyTorch Ultravox model."""
import math
from functools import cached_property, lru_cache
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
@ -17,7 +17,6 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@ -58,23 +57,18 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
UltravoxAudioEmbeddingInputs]
@lru_cache
def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
return WhisperFeatureExtractor.from_pretrained(model_id)
def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
hf_config = ctx.get_hf_config(UltravoxConfig)
return cached_feature_extractor(hf_config.audio_model_id)
def get_ultravox_max_audio_tokens(ctx: InputContext):
feature_extractor = whisper_feature_extractor(ctx)
return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
feature_extractor = self._get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)
return {"audio": max_audio_tokens}
def _get_hf_processor(
self,
*,
@ -322,8 +316,6 @@ class ModifiedWhisperEncoder(WhisperEncoder):
return hidden_states
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_ultravox_max_audio_tokens)
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

View File

@ -21,10 +21,15 @@ _I = TypeVar("_I")
class ModalityDataItems(ABC, Generic[_T, _I]):
def __init__(self, data: _T) -> None:
def __init__(self, data: _T, modality: str) -> None:
super().__init__()
self.data = data
self.modality = modality
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r}, "
f"len={len(self)})")
def __len__(self) -> int:
return self.get_count()
@ -64,14 +69,6 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
def __init__(self, data: Sequence[_T], modality: str) -> None:
super().__init__(data)
self.modality = modality
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r})")
def get_count(self) -> int:
return len(self.data)
@ -87,14 +84,6 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
def __init__(self, data: NestedTensors, modality: str) -> None:
super().__init__(data)
self.modality = modality
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r})")
def get_count(self) -> int:
return len(self.data)
@ -222,22 +211,13 @@ class MultiModalDataParser:
Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.
Args:
max_mm_counts (Mapping[str, int]): The maximum allowed number of items
belonging to each modality. This effectively sets a hard limit over
`--limit-mm-per-prompt`.
target_sr (float, optional): Enables automatic resampling of audio
items to the model's expected sampling rate.
"""
def __init__(
self,
*,
max_mm_counts: Mapping[str, int] = {},
target_sr: Optional[float] = None,
) -> None:
def __init__(self, *, target_sr: Optional[float] = None) -> None:
super().__init__()
self.max_mm_counts = max_mm_counts
self.target_sr = target_sr
def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
@ -345,7 +325,6 @@ class MultiModalDataParser:
def parse_mm_data(self,
mm_data: MultiModalDataDict) -> MultiModalDataItems:
max_mm_counts = self.max_mm_counts
subparsers = self._get_subparsers()
mm_items = MultiModalDataItems()
@ -353,16 +332,6 @@ class MultiModalDataParser:
if k not in subparsers:
raise ValueError(f"Unsupported modality: {k}")
modality_items = subparsers[k](v)
if k in max_mm_counts:
max_count = max_mm_counts[k]
if len(modality_items) > max_count:
raise ValueError(
f"This model supports at most {max_count} {k} items "
f"per prompt, but {len(modality_items)} {k} items "
"were given or set as its limit_mm_per_prompt.")
mm_items[k] = modality_items
mm_items[k] = subparsers[k](v)
return mm_items

View File

@ -624,6 +624,29 @@ class BaseMultiModalProcessor(ABC):
) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
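Together, these two hooks replace the per-model register_max_image_tokens / register_max_multimodal_tokens registrations removed elsewhere in this commit. A minimal sketch of how a subclass might implement them (the class and values are illustrative, not part of this diff; the remaining abstract methods are omitted for brevity):

from collections.abc import Mapping
from typing import Optional

from vllm.multimodal.processing import BaseMultiModalProcessor


class MyMultiModalProcessor(BaseMultiModalProcessor):
    """Illustrative subclass; remaining abstract methods omitted."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Any number of images, at most one audio clip per prompt.
        return {"image": None, "audio": 1}

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        # Must cover exactly the modalities returned above.
        return {"image": 576, "audio": 750}
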
def _get_data_parser(self) -> MultiModalDataParser:
"""
Construct a data parser to preprocess multi-modal data items
@ -653,7 +676,18 @@ class BaseMultiModalProcessor(ABC):
before passing them to :meth:`_get_hf_mm_data`.
"""
parser = self._get_data_parser()
return parser.parse_mm_data(mm_data)
mm_items = parser.parse_mm_data(mm_data)
mm_limits = self.ctx.get_mm_config().limit_per_prompt
for modality, items in mm_items.items():
limit = mm_limits.get(modality, 1)
if len(items) > limit:
raise ValueError(
f"You set {modality}={limit} (or defaulted to 1) in "
f"`--limit-mm-per-prompt`, but passed {len(items)} "
f"{modality} items in the same prompt.")
return mm_items
@abstractmethod
def _get_mm_fields_config(
@ -901,6 +935,17 @@ class BaseMultiModalProcessor(ABC):
return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]
def _always_apply_prompt_replacements(self) -> bool:
"""
A flag which can be overridden so that
:meth:`_apply_prompt_replacements` is always called even if we
detect that HF has performed processing via :meth:`_find_placeholders`.
This is useful in cases where :meth:`_find_placeholders` cannot be
reliably used to detect whether HF has performed processing or not.
"""
return False
def _apply_prompt_replacements(
self,
token_ids: list[int],
@ -995,7 +1040,7 @@ class BaseMultiModalProcessor(ABC):
all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
mm_item_counts)
if all_placeholders:
if all_placeholders and not self._always_apply_prompt_replacements():
tokenizer = self._get_tokenizer()
prompt_text = _decode(tokenizer, prompt_ids)
else:
@ -1009,10 +1054,27 @@ class BaseMultiModalProcessor(ABC):
mm_item_counts,
)
mm_placeholders = {
modality: [item.to_range() for item in items]
for modality, items in full_groupby_modality(all_placeholders)
}
mm_placeholders = dict[str, list[PlaceholderRange]]()
err_suffix = ("This suggests a problem with your implementation of "
"the merged multi-modal processor for this model, "
"particularly in the `_get_prompt_replacements` method.")
for modality, placeholders in full_groupby_modality(all_placeholders):
if modality not in mm_items:
raise AssertionError(
f"Expected no placeholders for {modality=}, "
f"but found {placeholders=}. Input items: {mm_items}"
f"\n{err_suffix}")
if len(placeholders) != len(mm_items[modality]):
raise AssertionError(
f"Expected length of {placeholders=} for {modality=} "
f"to equal that of input items: {mm_items[modality]}"
f"\n{err_suffix}")
mm_placeholders[modality] = [
item.to_range() for item in placeholders
]
return MultiModalInputsV2(
type="multimodal",
@ -1063,15 +1125,38 @@ class BaseMultiModalProcessor(ABC):
"""
raise NotImplementedError
def get_dummy_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_max_tokens: Mapping[str, int],
) -> DummyData:
def _get_and_validate_dummy_mm_counts(self) -> Mapping[str, int]:
mm_limit_per_prompt = self.ctx.get_mm_config().limit_per_prompt
supported_mm_limits = self.get_supported_mm_limits()
mm_limits = {
modality: mm_limit_per_prompt.get(modality, 1)
for modality in supported_mm_limits
}
for modality, supported_limit in supported_mm_limits.items():
limit = mm_limits[modality]
if supported_limit is not None and supported_limit < limit:
raise ValueError(
f"You set {modality}={limit} (or defaulted to 1) in "
f"`--limit-mm-per-prompt`, but this model only supports "
f"at most {supported_limit} {modality} items.")
return mm_limits
def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
mm_counts = self._get_and_validate_dummy_mm_counts()
mm_max_tokens_per_item = self.get_mm_max_tokens_per_item()
if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits`"
f"({set(mm_counts.keys())}) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
processor_inputs = self._get_dummy_mm_inputs(mm_counts)
mm_inputs = self.apply(
prompt_text=processor_inputs.prompt_text,
@ -1087,7 +1172,7 @@ class BaseMultiModalProcessor(ABC):
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
modality: mm_max_tokens[modality]
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
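With counts and per-item budgets resolved inside the processor, callers only supply the sequence length, matching the simplified call site in vllm/inputs/registry.py above. A hedged sketch of that shape, wrapped in a helper so the assumed inputs (model_config, tokenizer, seq_len) are explicit:

from vllm.multimodal import MULTIMODAL_REGISTRY

def build_dummy_data(model_config, tokenizer, seq_len: int):
    # The processor derives mm_counts from limit_mm_per_prompt and the
    # per-item budgets from get_mm_max_tokens_per_item() internally.
    processor = MULTIMODAL_REGISTRY.create_processor(model_config, tokenizer)
    return processor.get_dummy_data(seq_len)  # previously also took mm_counts/mm_max_tokens
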

View File

@ -15,6 +15,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import BaseMultiModalProcessor, ProcessingCache
from .utils import cached_get_tokenizer
from .video import VideoPlugin
if TYPE_CHECKING:
@ -219,6 +220,10 @@ class MultiModalRegistry:
Note:
This is currently directly used only in V1.
"""
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer)
return processor.get_mm_max_tokens_per_item()
return {
key: plugin.get_max_multimodal_tokens(model_config)