[VLM] Move supported limits and max tokens to merged multi-modal processor (#11669)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Cyrus Leung 2025-01-01 23:44:42 +08:00 committed by GitHub
parent 73001445fb
commit a115ac46b5
16 changed files with 351 additions and 361 deletions
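Note on the overall pattern applied across the files below: the standalone `get_max_*_tokens` helpers and their `MULTIMODAL_REGISTRY.register_max_image_tokens` / `register_max_multimodal_tokens` registrations are removed, and each merged multi-modal processor now declares its own limits by overriding `get_supported_mm_limits` and `get_mm_max_tokens_per_item`. A minimal sketch of the new shape (the class name and the token count are illustrative and not taken from this diff; model registration still happens via `@MULTIMODAL_REGISTRY.register_processor(...)` on the model class as before):

from typing import Mapping, Optional

from vllm.multimodal.processing import BaseMultiModalProcessor


class MyModelMultiModalProcessor(BaseMultiModalProcessor):
    """Illustrative processor; real subclasses also implement the other
    abstract methods (_get_mm_fields_config, _get_prompt_replacements, ...)."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # None = no model-side cap on the number of items per prompt;
        # an int = hard upper bound that --limit-mm-per-prompt cannot exceed.
        return {"image": None}

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        # Worst-case placeholder tokens per image, typically derived from the
        # HF config/processor; 1024 here is a made-up value for illustration.
        return {"image": 1024}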

View File

@@ -4,7 +4,7 @@ from typing import Optional
 import pytest
 from transformers import AutoTokenizer
-from vllm.inputs import InputContext, InputProcessingContext
+from vllm.inputs import InputProcessingContext
 from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
 from .....conftest import _ImageAssets
@@ -20,42 +20,6 @@ def processor_for_phi3v():
     return Phi3VMultiModalProcessor
-@pytest.fixture()
-def get_max_phi3v_image_tokens():
-    from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
-    return get_max_phi3v_image_tokens
-
-
-@pytest.mark.parametrize("model", models)
-@pytest.mark.parametrize("num_crops,expected_max_tokens", [
-    (4, 781),
-    (16, 2653),
-])
-def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
-                             num_crops: int, expected_max_tokens: int):
-    """Ensure get_max_phi3v_image_tokens handles num_crops properly."""
-    # NOTE: mm_processor_kwargs on the context in this test is unused, since
-    # this is testing the mapper directly. In practice, the processor kwargs
-    # are wrapped in a closure when calling the max tokens func. We explicitly
-    # do NOT use the mm_processor_kwargs in the model context here to ensure
-    # that the max image tokens implementation is referencing a mix of the
-    # kwargs to the function and the original mm_processor_kwargs in case
-    # values are somehow updated and end up in a bad state.
-    ctx = build_model_context(
-        model_name=model,
-        tokenizer_name=model,
-        trust_remote_code=True,
-        mm_processor_kwargs=None,
-    )
-    actual_max_tokens = get_max_phi3v_image_tokens(
-        InputContext(ctx.model_config),
-        num_crops=num_crops,
-    )
-    assert expected_max_tokens == actual_max_tokens
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "num_crops,expected_toks_per_img",
@@ -77,6 +41,7 @@ def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
         model_name=model,
         tokenizer_name=model,
         trust_remote_code=True,
+        limit_mm_per_prompt={"image": num_imgs},
     )
     tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
     ctx = InputProcessingContext(ctx.model_config, tokenizer)

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, Tuple
 import pytest
 from transformers import AutoTokenizer
-from vllm.inputs import InputContext, InputProcessingContext
+from vllm.inputs import InputProcessingContext
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
@@ -22,39 +22,6 @@ def processor_for_qwen2_vl():
     return Qwen2VLMultiModalProcessor
-@pytest.fixture()
-def get_max_qwen2_vl_image_tokens():
-    from vllm.model_executor.models.qwen2_vl import (
-        get_max_qwen2_vl_image_tokens)
-    return get_max_qwen2_vl_image_tokens
-
-
-@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
-    ({}, 16384),
-    ({
-        MIN_PIXELS: 64**2,
-        MAX_PIXELS: 512**2
-    }, 324),
-])
-@pytest.mark.parametrize("model", [MODEL])
-def test_qwen2_vl_max_image_tokens(
-    get_max_qwen2_vl_image_tokens,
-    model: str,
-    mm_processor_kwargs: Dict[str, Any],
-    expected_max_tokens: int,
-):
-    """Ensure that the max token calc handles min/max pixels properly."""
-    ctx = build_model_context(
-        model_name=model,
-        tokenizer_name=model,
-        mm_processor_kwargs=None,
-    )
-    actual_max_tokens = get_max_qwen2_vl_image_tokens(
-        InputContext(ctx.model_config), **mm_processor_kwargs)
-    assert actual_max_tokens == expected_max_tokens
 @pytest.mark.parametrize(
     "mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
         ({}, 1426, (5704, 1176)),
@@ -82,6 +49,7 @@ def test_processor_override(
         model_name=model,
         tokenizer_name=model,
         mm_processor_kwargs=None,
+        limit_mm_per_prompt={"image": num_imgs},
    )
     tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
     ctx = InputProcessingContext(ctx.model_config, tokenizer)

View File

@@ -538,6 +538,11 @@ def _test_processing_cache_correctness(
     else:
         hf_overrides = {}
+    limit_mm_per_prompt = {
+        modality: 3 if supports_multi else 1
+        for modality, supports_multi in modalities.items()
+    }
     model_config = ModelConfig(
         model_id,
         task="auto",
@@ -548,6 +553,7 @@ def _test_processing_cache_correctness(
         dtype="float16",
         revision=None,
         hf_overrides=hf_overrides,
+        limit_mm_per_prompt=limit_mm_per_prompt,
     )
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
@@ -580,18 +586,14 @@ def _test_processing_cache_correctness(
                 min_wh=128,
                 max_wh=256),
         "audio":
-        partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
+        partial(_rand_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
-    input_max_count = {
-        modality: 3 if supports_multi else 1
-        for modality, supports_multi in modalities.items()
-    }
     for batch_idx in range(num_batches):
         mm_data = {
             k:
             [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
-             for _ in range(rng.randint(input_max_count[k]))]
+             for _ in range(rng.randint(limit_mm_per_prompt[k]))]
             for k in modalities
         }

View File

@@ -331,13 +331,7 @@ class InputRegistry:
                 trust_remote_code=model_config.trust_remote_code,
             )
             processor = mm_registry.create_processor(model_config, tokenizer)
-            mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
-            mm_max_tokens = mm_registry.get_max_tokens_by_modality(
-                model_config)
-            dummy_data = processor.get_dummy_data(seq_len, mm_counts,
-                                                  mm_max_tokens)
+            dummy_data = processor.get_dummy_data(seq_len)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:

View File

@@ -1,5 +1,5 @@
-from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
-                    Union)
+from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)
 import torch
 import torch.nn as nn
@@ -9,7 +9,6 @@ from transformers import BatchFeature, PretrainedConfig
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -87,8 +86,8 @@ class AriaVisionModel(nn.Module):
     def forward(
         self,
         pixel_values: torch.Tensor,
-        pixel_mask: Optional[torch.BoolTensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]:
+        pixel_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
         vit_oup = self.vision_model(
@@ -100,7 +99,8 @@ class AriaVisionModel(nn.Module):
         return vit_oup, image_atts
-    def _create_patch_attention_mask(self, pixel_mask):
+    def _create_patch_attention_mask(
+            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
         if pixel_mask is None:
             return None
@@ -115,7 +115,8 @@ class AriaVisionModel(nn.Module):
         )
         return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-    def _create_image_attention_mask(self, patch_attention_mask):
+    def _create_image_attention_mask(
+            self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
         if patch_attention_mask is None:
             return None
@@ -125,13 +126,13 @@ class AriaVisionModel(nn.Module):
 class FFN(nn.Module):
-    def __init__(self, embed_dim, ff_dim, output_dim):
+    def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
         super().__init__()
         self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
         self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
         self.act = get_act_fn("gelu_new")
-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.linear_in(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states, _ = self.linear_out(hidden_states)
@@ -140,7 +141,7 @@ class FFN(nn.Module):
 class CrossAttention(nn.Module):
-    def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
+    def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
         super().__init__()
         self.num_heads = num_heads
         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
@@ -149,12 +150,16 @@ class CrossAttention(nn.Module):
         self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
         self.linear = nn.Linear(embed_dim, embed_dim)
-        self.dropout = nn.Dropout(drop_out_rate)
         self.layer_norm = nn.LayerNorm(embed_dim)
         self.ln_kv = nn.LayerNorm(kv_dim)
-    def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
+    def forward(
+        self,
+        x: torch.Tensor,
+        hidden_states: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         normed_hidden_states = self.layer_norm(hidden_states)
         query = self.q_proj(normed_hidden_states).permute(1, 0, 2)
@@ -169,11 +174,7 @@ class CrossAttention(nn.Module):
         attn_output = attn_output.permute(1, 0, 2)
-        if add_residual:
-            attn_output = hidden_states + self.dropout(
-                self.linear(attn_output))
-        else:
-            attn_output = self.dropout(self.linear(attn_output))
+        attn_output = self.linear(attn_output)
         return attn_output
@@ -201,14 +202,14 @@ class AriaProjector(nn.Module):
     def __init__(
         self,
-        patch_to_query_dict,
-        embed_dim,
-        num_heads,
-        kv_dim,
-        ff_dim,
-        output_dim,
-        norm_layer=nn.LayerNorm,
-    ):
+        patch_to_query_dict: dict[int, int],
+        embed_dim: int,
+        num_heads: int,
+        kv_dim: int,
+        ff_dim: int,
+        output_dim: int,
+        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
+    ) -> None:
         super().__init__()
         self.patch_to_query_dict = patch_to_query_dict
         self.embed_dim = embed_dim
@@ -224,7 +225,11 @@ class AriaProjector(nn.Module):
         self.ln_ffn = norm_layer(embed_dim)
         self.ffn = FFN(embed_dim, ff_dim, output_dim)
-    def forward(self, x, attn_mask=None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         bs = x.shape[0]
         queries = self.query.unsqueeze(0).repeat(bs, 1, 1)
@@ -442,13 +447,18 @@ def build_mm_projector(config: PretrainedConfig):
     )
-def get_max_aria_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config()
-    return max(hf_config.projector_patch_to_query_dict.values())
 class AriaMultiModalProcessor(BaseMultiModalProcessor):
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def _get_num_image_tokens(self) -> int:
+        hf_config = self.ctx.get_hf_config()
+        return max(hf_config.projector_patch_to_query_dict.values())
+
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        return {"image": self._get_num_image_tokens()}
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -468,13 +478,13 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
         hf_config = self.ctx.get_hf_config()
         image_token_id = hf_config.image_token_index
-        max_image_tokens = get_max_aria_image_tokens(self.ctx)
+        num_image_tokens = self._get_num_image_tokens()
         return [
             PromptReplacement(
                 modality="image",
                 target=[image_token_id],
-                replacement=[image_token_id] * max_image_tokens,
+                replacement=[image_token_id] * num_image_tokens,
             )
         ]
@@ -504,7 +514,6 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
     )
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
 class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """

View File

@@ -9,7 +9,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2Processor,
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
-from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -18,7 +17,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
-from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
@@ -398,15 +396,17 @@ class Blip2QFormerModel(nn.Module):
         return sequence_output
-def get_max_blip2_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(Blip2Config)
-    return hf_config.num_query_tokens
 class Blip2MultiModalProcessor(BaseMultiModalProcessor):
-    def _get_data_parser(self) -> MultiModalDataParser:
-        return MultiModalDataParser(max_mm_counts={"image": 1})
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
+    def _get_num_image_tokens(self) -> int:
+        hf_config = self.ctx.get_hf_config(Blip2Config)
+        return hf_config.num_query_tokens
+
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        return {"image": self._get_num_image_tokens()}
     def _get_hf_processor(self) -> Blip2Processor:
         return self.ctx.get_hf_processor(Blip2Processor)
@@ -427,7 +427,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        max_image_tokens = get_max_blip2_image_tokens(self.ctx)
+        max_image_tokens = self._get_num_image_tokens()
         return [
             PromptReplacement(
@@ -480,7 +480,6 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
     )
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

View File

@@ -11,7 +11,6 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -31,7 +30,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
-from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
@@ -43,11 +41,6 @@ from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
-# These configs are not part of the model config but the preprocessor
-# and processor files, so we hardcode them in the model file for now.
-CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
-CHAMELEON_IMAGE_SEQ_LENGTH = 1024
 class ChameleonImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -55,14 +48,17 @@ class ChameleonImagePixelInputs(TypedDict):
     """Shape: `(batch_size * num_images, num_channels, height, width)`"""
-def get_max_chameleon_image_tokens(ctx: InputContext):
-    return CHAMELEON_IMAGE_SEQ_LENGTH
 class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
-    def _get_data_parser(self) -> MultiModalDataParser:
-        return MultiModalDataParser(max_mm_counts={"image": 1})
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
+    def _get_num_image_tokens(self) -> int:
+        processor = self._get_hf_processor()
+        return processor.image_seq_length
+
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        return {"image": self._get_num_image_tokens()}
     def _get_hf_processor(self) -> ChameleonProcessor:
         return self.ctx.get_hf_processor(ChameleonProcessor)
@@ -88,7 +84,7 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
                 target="<image>",
                 replacement="".join([
                     processor.image_start_token,
-                    processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH,
+                    processor.image_token * self._get_num_image_tokens(),
                     processor.image_end_token,
                 ]),
             )
@@ -98,12 +94,15 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
         self,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
+        config = self.ctx.get_hf_config(ChameleonConfig)
+        width = height = config.vq_config.resolution
+
         num_images = mm_counts.get("image", 0)
         mm_data = {
             "image":
-            self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH,
-                                   height=CHAMELEON_CROP_SIZE_HEIGHT,
+            self._get_dummy_images(width=width,
+                                   height=height,
                                    num_images=num_images)
         }
@@ -902,7 +901,6 @@ class ChameleonModel(nn.Module):
         return hidden_states
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
 class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
@@ -931,9 +929,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
             self.model.make_empty_intermediate_tensors)
     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
-                         CHAMELEON_CROP_SIZE_WIDTH)
+        vq_config: ChameleonVQVAEConfig = self.config.vq_config
+        expected_dims = (3, vq_config.resolution, vq_config.resolution)
         actual_dims = tuple(data.shape[1:])
         if actual_dims != expected_dims:

View File

@@ -25,7 +25,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import InputContext
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
@@ -34,7 +33,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
-from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
@@ -48,9 +47,6 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
 _IMAGE_TOKEN_ID = 71011
 _NEWLINE_TOKEN_ID = 71019
-MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
-MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
 class FuyuImagePatchInputs(TypedDict):
     type: Literal["image_patches"]
@@ -67,43 +63,49 @@ class FuyuImagePatchInputs(TypedDict):
     """
-def _get_fuyu_num_image_tokens(
-    image_height: int,
-    image_width: int,
-) -> Tuple[int, int]:
-    """
-    Calculate the number of image tokens needed for a given image size.
-
-    The expected Fuyu image prompts can be expressed as:
-
-    .. code-block::
-        (image_token * ncols + newline_token) * nrows
-
-    Args:
-        image_size: Tuple[int, int] - `(width, height)` of the image
-
-    Returns:
-        ncols: int - number of image tokens in `x` direction
-        nrows: int - number of image tokens in `y` direction
-    """
-    ncols = math.ceil(image_width / 30)
-    nrows = math.ceil(image_height / 30)
-    return ncols, nrows
-
-
-def get_max_fuyu_image_tokens(ctx: InputContext):
-    ncols, nrows = _get_fuyu_num_image_tokens(
-        image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
-        image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
-    )
-
-    return (ncols + 1) * nrows
 class FuyuMultiModalProcessor(BaseMultiModalProcessor):
-    def _get_data_parser(self) -> MultiModalDataParser:
-        return MultiModalDataParser(max_mm_counts={"image": 1})
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
+    def _get_image_target_size(self) -> ImageSize:
+        processor = self._get_hf_processor()
+        image_processor: FuyuImageProcessor = processor.image_processor
+
+        target_size = image_processor.size
+        return ImageSize(width=target_size["width"],
+                         height=target_size["height"])
+
+    def _get_image_grid_size(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> tuple[int, int]:
+        target_width, target_height = self._get_image_target_size()
+
+        if not (image_width <= target_width and image_height <= target_height):
+            height_scale_factor = target_height / image_height
+            width_scale_factor = target_width / image_width
+            optimal_scale_factor = min(height_scale_factor, width_scale_factor)
+
+            image_height = int(image_height * optimal_scale_factor)
+            image_width = int(image_width * optimal_scale_factor)
+
+        ncols = math.ceil(image_width / 30)
+        nrows = math.ceil(image_height / 30)
+        return ncols, nrows
+
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        target_width, target_height = self._get_image_target_size()
+
+        max_ncols, max_nrows = self._get_image_grid_size(
+            image_width=target_width,
+            image_height=target_height,
+        )
+        max_image_tokens = (max_ncols + 1) * max_nrows
+
+        return {"image": max_image_tokens}
     def _get_hf_processor(self) -> FuyuProcessor:
         return self.ctx.get_hf_processor(FuyuProcessor)
@@ -166,28 +168,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
         eot_token_id = tokenizer.bos_token_id
         assert isinstance(eot_token_id, int)
-        hf_processor = self._get_hf_processor()
-        image_processor: FuyuImageProcessor = hf_processor.image_processor
-        target_size = image_processor.size
-        target_height, target_width = (target_size["height"],
-                                       target_size["width"])
         def get_replacement_fuyu(item_idx: int):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
-            width, height = image_size.width, image_size.height
-            if not (width <= target_width and height <= target_height):
-                height_scale_factor = target_height / height
-                width_scale_factor = target_width / width
-                optimal_scale_factor = min(height_scale_factor,
-                                           width_scale_factor)
-                height = int(height * optimal_scale_factor)
-                width = int(width * optimal_scale_factor)
-            ncols, nrows = _get_fuyu_num_image_tokens(
-                image_width=width,
-                image_height=height,
+            ncols, nrows = self._get_image_grid_size(
+                image_width=image_size.width,
+                image_height=image_size.height,
             )
             return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
@@ -225,12 +212,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
         self,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
+        target_width, target_height = self._get_image_target_size()
+
         num_images = mm_counts.get("image", 0)
         mm_data = {
             "image":
-            self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
-                                   height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
                                    num_images=num_images)
         }
@@ -240,7 +228,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
     )
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
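Note: the token-count logic that moved into `_get_image_grid_size` / `get_mm_max_tokens_per_item` can be checked in isolation. A standalone sketch under the assumption that the HF image processor's target size equals the previously hard-coded 1920x1080 (the helper name below is hypothetical, not part of the diff):

import math


def fuyu_image_grid(image_width: int, image_height: int,
                    target_width: int = 1920,
                    target_height: int = 1080) -> tuple[int, int]:
    """Scale the image down to fit the target size, then count 30px patches."""
    if not (image_width <= target_width and image_height <= target_height):
        scale = min(target_height / image_height, target_width / image_width)
        image_height = int(image_height * scale)
        image_width = int(image_width * scale)

    ncols = math.ceil(image_width / 30)
    nrows = math.ceil(image_height / 30)
    return ncols, nrows


# Worst case at the assumed target size: 64 columns x 36 rows, plus one
# newline token per row -> (64 + 1) * 36 = 2340 image tokens.
ncols, nrows = fuyu_image_grid(1920, 1080)
assert (ncols + 1) * nrows == 2340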

View File

@@ -119,6 +119,12 @@ def get_max_llava_image_tokens(ctx: InputContext):
 class LlavaMultiModalProcessor(BaseMultiModalProcessor):
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        return {"image": get_max_llava_image_tokens(self.ctx)}
     def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
         return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
@@ -324,7 +330,6 @@ def init_vision_tower_for_llava(
     raise NotImplementedError(msg)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
 class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     # BitandBytes specific attributes
@@ -649,7 +654,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 # To use this model, please use
 # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
 class MantisForConditionalGeneration(LlavaForConditionalGeneration):
     pass

View File

@@ -23,7 +23,6 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import InputContext
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -306,25 +305,32 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline
-def get_max_phi3v_image_tokens(
-    ctx: InputContext,
-    *,
-    num_crops: Optional[int] = None,
-) -> int:
-    hf_processor_mm_kwargs = {}
-    if num_crops:
-        hf_processor_mm_kwargs["num_crops"] = num_crops
-
-    processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)
-
-    return processor.calc_num_image_tokens_from_image_size(
-        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
-        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
-    )
 class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def _get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        processor = self._get_hf_processor()
+
+        return processor.calc_num_image_tokens_from_image_size(  # type: ignore
+            width=image_width,
+            height=image_height,
+        )
+
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        max_image_tokens = self._get_num_image_tokens(
+            image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+            image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+        )
+
+        return {"image": max_image_tokens}
     def _get_hf_processor(
         self,
         *,
@@ -332,6 +338,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
     ) -> ProcessorMixin:
         if num_crops is not None:
             return self.ctx.get_hf_processor(num_crops=num_crops)
+
         return self.ctx.get_hf_processor()
     def _call_hf_processor(
@@ -375,7 +382,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
     ) -> list[PromptReplacement]:
         hf_processor = self._get_hf_processor()
         image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
-        image_processor = hf_processor.image_processor  # type: ignore
         tokenizer = self._get_tokenizer()
         bos_token_id = tokenizer.bos_token_id
@@ -385,9 +391,9 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
-            num_tokens = image_processor.calc_num_image_tokens_from_image_size(
-                width=image_size.width,
-                height=image_size.height,
+            num_tokens = self._get_num_image_tokens(
+                image_width=image_size.width,
+                image_height=image_size.height,
             )
             return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
@@ -467,7 +473,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
         return result
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
 @MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
 class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     hf_to_vllm_mapper = WeightsMapper(

View File

@@ -33,13 +33,12 @@ from transformers.models.whisper import WhisperFeatureExtractor
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import InputContext
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors)
-from vllm.multimodal.parse import MultiModalDataParser
+from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)
@@ -80,15 +79,18 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
     return feat_lengths, output_lengths
-def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
-    hf_config = ctx.get_hf_config(Qwen2AudioConfig)
-    max_source_position = hf_config.audio_config.max_source_positions
-    output_lengths = (max_source_position - 2) // 2 + 1
-    return output_lengths
 class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"audio": None}
+
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
+        max_source_positions = hf_config.audio_config.max_source_positions
+        max_output_lengths = (max_source_positions - 2) // 2 + 1
+        return {"audio": max_output_lengths}
     def _get_hf_processor(
         self,
         *,
@@ -157,11 +159,21 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
             audio_output_lengths = []
         else:
             assert isinstance(feature_attention_mask, torch.Tensor)
-            _, audio_output_lengths = _get_feat_extract_output_lengths(
+            _, audio_output_lens = _get_feat_extract_output_lengths(
                 feature_attention_mask.sum(-1))
+            audio_output_lengths = audio_output_lens.tolist()
         def get_replacement_qwen2_audio(item_idx: int):
-            return [placeholder] * audio_output_lengths[item_idx]
+            num_placeholders = audio_output_lengths[item_idx]
+            if num_placeholders == 0:
+                audios = mm_items.get_items("audio", AudioProcessorItems)
+                audio = audios.get(item_idx)
+                raise ValueError(
+                    f"The audio {audio} (len={len(audio)}) is too short "
+                    "to be represented inside the model")
+
+            return [placeholder] * num_placeholders
         return [
             PromptReplacement(
@@ -171,6 +183,14 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
             )
         ]
+    def _always_apply_prompt_replacements(self) -> bool:
+        # HF never applies prompt replacements, so we have to do it ourselves.
+        # _find_placeholders may incorrectly think that HF has already
+        # performed processing for multi-audio input when the input audios
+        # are short (the corresponding placeholders may take up fewer tokens
+        # than the number of audio items).
+        return True
     def _get_dummy_mm_inputs(
         self,
         mm_counts: Mapping[str, int],
@@ -192,8 +212,6 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
         )
-@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
-    "audio", get_max_qwen2_audio_audio_tokens)
 @MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
 class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
                                          SupportsPP):

View File

@@ -40,7 +40,6 @@ from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.inputs import InputContext
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import QuickGELU
@@ -650,8 +649,9 @@ def _get_vision_info(
     width: int,
     min_pixels: int,
     max_pixels: int,
+    *,
     do_resize: bool = True,
-    data_type_key: str = "image",
+    modality: str = "image",
     mm_count: int = 1,
 ):
     """Get information (resized height / width and number of vision tokens)
@@ -671,11 +671,12 @@ def _get_vision_info(
     else:
         resized_height, resized_width = height, width
-    if data_type_key == "image":
+    if modality == "image":
         grid_t = mm_count
-    else:
-        assert data_type_key == "video"
+    elif modality == "video":
         grid_t = max(mm_count // temporal_patch_size, 1)
+    else:
+        raise ValueError(f"Modality {modality} is not supported")
     grid_h = resized_height // patch_size
     grid_w = resized_width // patch_size
@@ -691,41 +692,11 @@ def _get_image_processor(hf_processor: Qwen2VLProcessor):
     return image_processor
-def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
-                               data_type_key: str,
-                               *,
-                               min_pixels: Optional[int] = None,
-                               max_pixels: Optional[int] = None) -> int:
-    hf_config = ctx.get_hf_config(Qwen2VLConfig)
-    vision_config = hf_config.vision_config
-
-    hf_processor = ctx.get_hf_processor(Qwen2VLProcessor)
-    image_processor = _get_image_processor(hf_processor)
-
-    _, _, max_llm_image_tokens = _get_vision_info(
-        vision_config,
-        height=9999999,
-        width=9999999,
-        min_pixels=min_pixels or image_processor.min_pixels,
-        max_pixels=max_pixels or image_processor.max_pixels,
-        data_type_key=data_type_key,
-    )
-    return max_llm_image_tokens
-
-
-get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens,
-                                        data_type_key="image")
-get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
-                                        data_type_key="video")
 class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
                                             dict[str, torch.Tensor]]):
     def __init__(self, data: dict, modality: str) -> None:
-        super().__init__(data)
-        self.modality = modality
+        super().__init__(data, modality)
         grid_thw = data[f"{modality}_grid_thw"]
         slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
@@ -734,9 +705,6 @@ class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
             for i in range(len(grid_thw))
         ]
-    def __repr__(self) -> str:
-        return (f"{type(self).__name__}(modality={self.modality!r})")
     def get_count(self) -> int:
         return len(self.data[f"{self.modality}_grid_thw"])
@@ -792,6 +760,32 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
 class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None, "video": None}
+
+    def _get_max_mm_tokens(self, modality: str) -> int:
+        hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
+        vision_config = hf_config.vision_config
+
+        hf_processor = self._get_hf_processor()
+        image_processor = _get_image_processor(hf_processor)
+
+        _, _, max_llm_image_tokens = _get_vision_info(
+            vision_config,
+            height=9999999,
+            width=9999999,
+            min_pixels=image_processor.min_pixels,
+            max_pixels=image_processor.max_pixels,
+            modality=modality,
+        )
+        return max_llm_image_tokens
+
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        return {
+            "image": self._get_max_mm_tokens("image"),
+            "video": self._get_max_mm_tokens("video"),
+        }
     def _get_data_parser(self) -> MultiModalDataParser:
         return Qwen2MultiModalDataParser()
@@ -908,9 +902,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
         )
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
-@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
-    "video", get_max_qwen2_vl_video_tokens)
 @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
 class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsLoRA, SupportsPP):

View File

@@ -2,7 +2,7 @@
 """PyTorch Ultravox model."""
 import math
-from functools import cached_property, lru_cache
+from functools import cached_property
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
@@ -17,7 +17,6 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.inputs import InputContext
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -58,23 +57,18 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
                             UltravoxAudioEmbeddingInputs]
-@lru_cache
-def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
-    return WhisperFeatureExtractor.from_pretrained(model_id)
-
-
-def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
-    hf_config = ctx.get_hf_config(UltravoxConfig)
-    return cached_feature_extractor(hf_config.audio_model_id)
-
-
-def get_ultravox_max_audio_tokens(ctx: InputContext):
-    feature_extractor = whisper_feature_extractor(ctx)
-    return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)
 class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"audio": None}
+
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        feature_extractor = self._get_feature_extractor()
+        max_audio_tokens = math.ceil(feature_extractor.chunk_length *
+                                     _AUDIO_TOKENS_PER_SECOND)
+        return {"audio": max_audio_tokens}
     def _get_hf_processor(
         self,
         *,
@@ -322,8 +316,6 @@ class ModifiedWhisperEncoder(WhisperEncoder):
         return hidden_states
-@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
-    "audio", get_ultravox_max_audio_tokens)
 @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
 class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

View File

@@ -21,10 +21,15 @@ _I = TypeVar("_I")
 class ModalityDataItems(ABC, Generic[_T, _I]):
-    def __init__(self, data: _T) -> None:
+    def __init__(self, data: _T, modality: str) -> None:
         super().__init__()
         self.data = data
+        self.modality = modality
+
+    def __repr__(self) -> str:
+        return (f"{type(self).__name__}(modality={self.modality!r}, "
+                f"len={len(self)})")
     def __len__(self) -> int:
         return self.get_count()
@@ -64,14 +69,6 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
 class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
-    def __init__(self, data: Sequence[_T], modality: str) -> None:
-        super().__init__(data)
-        self.modality = modality
-
-    def __repr__(self) -> str:
-        return (f"{type(self).__name__}(modality={self.modality!r})")
     def get_count(self) -> int:
         return len(self.data)
@@ -87,14 +84,6 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
 class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
-    def __init__(self, data: NestedTensors, modality: str) -> None:
-        super().__init__(data)
-        self.modality = modality
-
-    def __repr__(self) -> str:
-        return (f"{type(self).__name__}(modality={self.modality!r})")
     def get_count(self) -> int:
         return len(self.data)
@@ -222,22 +211,13 @@ class MultiModalDataParser:
     Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.
     Args:
-        max_mm_counts (Mapping[str, int]): The maximum allowed number of items
-            belonging to each modality. This effectively sets a hard limit over
-            `--limit-mm-per-prompt`.
         target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
    """
-    def __init__(
-        self,
-        *,
-        max_mm_counts: Mapping[str, int] = {},
-        target_sr: Optional[float] = None,
-    ) -> None:
+    def __init__(self, *, target_sr: Optional[float] = None) -> None:
         super().__init__()
-        self.max_mm_counts = max_mm_counts
         self.target_sr = target_sr
     def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
@@ -345,7 +325,6 @@ class MultiModalDataParser:
     def parse_mm_data(self,
                       mm_data: MultiModalDataDict) -> MultiModalDataItems:
-        max_mm_counts = self.max_mm_counts
         subparsers = self._get_subparsers()
         mm_items = MultiModalDataItems()
@@ -353,16 +332,6 @@ class MultiModalDataParser:
             if k not in subparsers:
                 raise ValueError(f"Unsupported modality: {k}")
-            modality_items = subparsers[k](v)
-            if k in max_mm_counts:
-                max_count = max_mm_counts[k]
-                if len(modality_items) > max_count:
-                    raise ValueError(
-                        f"This model supports at most {max_count} {k} items "
-                        f"per prompt, but {len(modality_items)} {k} items "
-                        "were given or set as its limit_mm_per_prompt.")
-
-            mm_items[k] = modality_items
+            mm_items[k] = subparsers[k](v)
         return mm_items

View File

@@ -624,6 +624,29 @@ class BaseMultiModalProcessor(ABC):
     ) -> MultiModalInputsV2:
         return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
+    @abstractmethod
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        """
+        Return the maximum supported number of items for each modality.
+
+        A value of `None` means unlimited number of items.
+
+        Omitting a modality from the returned dictionary means that
+        it is not supported at all.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
+        """
+        Get the maximum possible number of tokens per data item
+        for each modality.
+
+        The dictionary returned by this method should have the same
+        keys as that returned by :meth:`get_supported_mm_limits`.
+        """
+        raise NotImplementedError
     def _get_data_parser(self) -> MultiModalDataParser:
         """
         Construct a data parser to preprocess multi-modal data items
@@ -653,7 +676,18 @@ class BaseMultiModalProcessor(ABC):
         before passing them to :meth:`_get_hf_mm_data`.
         """
         parser = self._get_data_parser()
-        return parser.parse_mm_data(mm_data)
+        mm_items = parser.parse_mm_data(mm_data)
+
+        mm_limits = self.ctx.get_mm_config().limit_per_prompt
+        for modality, items in mm_items.items():
+            limit = mm_limits.get(modality, 1)
+            if len(items) > limit:
+                raise ValueError(
+                    f"You set {modality}={limit} (or defaulted to 1) in "
+                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
+                    f"{modality} items in the same prompt.")
+
+        return mm_items
     @abstractmethod
     def _get_mm_fields_config(
@@ -901,6 +935,17 @@ class BaseMultiModalProcessor(ABC):
         return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]
+    def _always_apply_prompt_replacements(self) -> bool:
+        """
+        A flag which can be overridden so that
+        :meth:`_apply_prompt_replacements` is always called even if we
+        detect that HF has performed processing via :meth:`_find_placeholders`.
+
+        This is useful in cases where :meth:`_find_placeholders` cannot be
+        reliably used to detect whether HF has performed processing or not.
+        """
+        return False
     def _apply_prompt_replacements(
         self,
         token_ids: list[int],
@@ -995,7 +1040,7 @@ class BaseMultiModalProcessor(ABC):
         all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
                                                    mm_item_counts)
-        if all_placeholders:
+        if all_placeholders and not self._always_apply_prompt_replacements():
             tokenizer = self._get_tokenizer()
             prompt_text = _decode(tokenizer, prompt_ids)
         else:
@@ -1009,10 +1054,27 @@ class BaseMultiModalProcessor(ABC):
             mm_item_counts,
         )
-        mm_placeholders = {
-            modality: [item.to_range() for item in items]
-            for modality, items in full_groupby_modality(all_placeholders)
-        }
+        mm_placeholders = dict[str, list[PlaceholderRange]]()
+        err_suffix = ("This suggests a problem with your implementation of "
+                      "the merged multi-modal processor for this model, "
+                      "particularly in the `_get_prompt_replacements` method.")
+
+        for modality, placeholders in full_groupby_modality(all_placeholders):
+            if modality not in mm_items:
+                raise AssertionError(
+                    f"Expected no placeholders for {modality=}, "
+                    f"but found {placeholders=}. Input items: {mm_items}"
+                    f"\n{err_suffix}")
+
+            if len(placeholders) != len(mm_items[modality]):
+                raise AssertionError(
+                    f"Expected length of {placeholders=} for {modality=} "
+                    f"to equal that of input items: {mm_items[modality]}"
+                    f"\n{err_suffix}")
+
+            mm_placeholders[modality] = [
+                item.to_range() for item in placeholders
+            ]
         return MultiModalInputsV2(
             type="multimodal",
@@ -1063,15 +1125,38 @@ class BaseMultiModalProcessor(ABC):
         """
         raise NotImplementedError
-    def get_dummy_data(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-        mm_max_tokens: Mapping[str, int],
-    ) -> DummyData:
+    def _get_and_validate_dummy_mm_counts(self) -> Mapping[str, int]:
+        mm_limit_per_prompt = self.ctx.get_mm_config().limit_per_prompt
+        supported_mm_limits = self.get_supported_mm_limits()
+
+        mm_limits = {
+            modality: mm_limit_per_prompt.get(modality, 1)
+            for modality in supported_mm_limits
+        }
+
+        for modality, supported_limit in supported_mm_limits.items():
+            limit = mm_limits[modality]
+            if supported_limit is not None and supported_limit < limit:
+                raise ValueError(
+                    f"You set {modality}={limit} (or defaulted to 1) in "
+                    f"`--limit-mm-per-prompt`, but this model only supports "
+                    f"at most {supported_limit} {modality} items.")
+
+        return mm_limits
+
+    def get_dummy_data(self, seq_len: int) -> DummyData:
         # Avoid circular import
         from vllm.sequence import SequenceData
+        mm_counts = self._get_and_validate_dummy_mm_counts()
+        mm_max_tokens_per_item = self.get_mm_max_tokens_per_item()
+        if mm_counts.keys() != mm_max_tokens_per_item.keys():
+            raise AssertionError(
+                "The keys returned by `get_supported_mm_limits`"
+                f"({set(mm_counts.keys())}) should be the same as those "
+                "returned by `get_mm_max_tokens_per_item` "
+                f"({set(mm_max_tokens_per_item.keys())})")
+
         processor_inputs = self._get_dummy_mm_inputs(mm_counts)
         mm_inputs = self.apply(
             prompt_text=processor_inputs.prompt_text,
@@ -1087,7 +1172,7 @@ class BaseMultiModalProcessor(ABC):
             for modality, placeholders in placeholders_by_modality.items()
         }
         expected_placeholders_by_modality = {
-            modality: mm_max_tokens[modality]
+            modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
             for modality in placeholders_by_modality
         }
         if total_placeholders_by_modality != expected_placeholders_by_modality:
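Note: with these changes, the profiling path no longer passes counts or token budgets into the processor; `get_dummy_data(seq_len)` resolves them from `--limit-mm-per-prompt`, validates them against `get_supported_mm_limits`, and checks the dummy placeholders against `get_mm_max_tokens_per_item() * count`. A rough sketch of how a caller is now wired up (mirroring the change in the input registry; the function name below is illustrative):

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer


def profile_dummy_data(model_config, seq_len: int):
    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    processor = MULTIMODAL_REGISTRY.create_processor(model_config, tokenizer)

    # mm_counts and mm_max_tokens are no longer supplied by the caller;
    # they are derived from --limit-mm-per-prompt and the processor itself.
    return processor.get_dummy_data(seq_len)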

View File

@@ -15,6 +15,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
 from .image import ImagePlugin
 from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
 from .processing import BaseMultiModalProcessor, ProcessingCache
+from .utils import cached_get_tokenizer
 from .video import VideoPlugin
 if TYPE_CHECKING:
@@ -219,6 +220,10 @@ class MultiModalRegistry:
         Note:
             This is currently directly used only in V1.
         """
+        if self.has_processor(model_config):
+            tokenizer = cached_get_tokenizer(model_config.tokenizer)
+            processor = self.create_processor(model_config, tokenizer)
+            return processor.get_mm_max_tokens_per_item()
+
         return {
             key: plugin.get_max_multimodal_tokens(model_config)