[VLM] Move supported limits and max tokens to merged multi-modal processor (#11669)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
parent 73001445fb
commit a115ac46b5
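To orient the reader before the per-file hunks: the sketch below (plain Python, not vLLM code) shows the shape of the interface this commit moves onto the merged multi-modal processor. Each processor now reports its own supported item limits and per-item token maxima, and get_dummy_data(seq_len) derives profiling inputs from them instead of taking mm_counts/mm_max_tokens from the registry. The class names and the 32-token default here are illustrative assumptions, not code from this diff.

# Illustrative sketch only (plain Python, no vLLM imports). It mirrors the two
# methods this commit adds to the merged multi-modal processor:
# get_supported_mm_limits() replaces the parser-level max_mm_counts, and
# get_mm_max_tokens_per_item() replaces the standalone get_max_*_tokens
# functions previously registered via register_max_image_tokens.
from abc import ABC, abstractmethod
from typing import Mapping, Optional


class MergedProcessorSketch(ABC):

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Max number of items per modality; None means unlimited."""
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        """Max tokens a single item of each modality can expand to."""
        raise NotImplementedError


class Blip2LikeSketch(MergedProcessorSketch):
    """BLIP-2-style example: one image per prompt, a fixed number of
    query tokens per image (32 is an assumed default)."""

    def __init__(self, num_query_tokens: int = 32) -> None:
        self.num_query_tokens = num_query_tokens

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        return {"image": self.num_query_tokens}


if __name__ == "__main__":
    proc = Blip2LikeSketch()
    limit_mm_per_prompt = {"image": 1}  # what --limit-mm-per-prompt would set

    for modality, supported in proc.get_supported_mm_limits().items():
        limit = limit_mm_per_prompt.get(modality, 1)
        if supported is not None and supported < limit:
            raise ValueError(f"{modality}={limit} exceeds the supported "
                             f"maximum of {supported}")
        # Profiling budget per modality = per-item maximum * item count.
        budget = proc.get_mm_max_tokens_per_item()[modality] * limit
        print(modality, budget)  # image 32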
@@ -4,7 +4,7 @@ from typing import Optional
import pytest
from transformers import AutoTokenizer

from vllm.inputs import InputContext, InputProcessingContext
from vllm.inputs import InputProcessingContext
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID

from .....conftest import _ImageAssets
@@ -20,42 +20,6 @@ def processor_for_phi3v():
    return Phi3VMultiModalProcessor


@pytest.fixture()
def get_max_phi3v_image_tokens():
    from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
    return get_max_phi3v_image_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_crops,expected_max_tokens", [
    (4, 781),
    (16, 2653),
])
def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
                             num_crops: int, expected_max_tokens: int):
    """Ensure get_max_phi3v_image_tokens handles num_crops properly."""
    # NOTE: mm_processor_kwargs on the context in this test is unused, since
    # this is testing the mapper directly. In practice, the processor kwargs
    # are wrapped in a closure when calling the max tokens func. We explicitly
    # do NOT use the mm_processor_kwargs in the model context here to ensure
    # that the max image tokens implementation is referencing a mix of the
    # kwargs to the function and the original mm_processor_kwargs in case
    # values are somehow updated and end up in a bad state.
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        trust_remote_code=True,
        mm_processor_kwargs=None,
    )

    actual_max_tokens = get_max_phi3v_image_tokens(
        InputContext(ctx.model_config),
        num_crops=num_crops,
    )

    assert expected_max_tokens == actual_max_tokens


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "num_crops,expected_toks_per_img",
@@ -77,6 +41,7 @@ def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
        model_name=model,
        tokenizer_name=model,
        trust_remote_code=True,
        limit_mm_per_prompt={"image": num_imgs},
    )
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    ctx = InputProcessingContext(ctx.model_config, tokenizer)
@@ -3,7 +3,7 @@ from typing import Any, Dict, Tuple
import pytest
from transformers import AutoTokenizer

from vllm.inputs import InputContext, InputProcessingContext
from vllm.inputs import InputProcessingContext

from .....conftest import _ImageAssets
from ....utils import build_model_context
@@ -22,39 +22,6 @@ def processor_for_qwen2_vl():
    return Qwen2VLMultiModalProcessor


@pytest.fixture()
def get_max_qwen2_vl_image_tokens():
    from vllm.model_executor.models.qwen2_vl import (
        get_max_qwen2_vl_image_tokens)
    return get_max_qwen2_vl_image_tokens


@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [
    ({}, 16384),
    ({
        MIN_PIXELS: 64**2,
        MAX_PIXELS: 512**2
    }, 324),
])
@pytest.mark.parametrize("model", [MODEL])
def test_qwen2_vl_max_image_tokens(
    get_max_qwen2_vl_image_tokens,
    model: str,
    mm_processor_kwargs: Dict[str, Any],
    expected_max_tokens: int,
):
    """Ensure that the max token calc handles min/max pixels properly."""
    ctx = build_model_context(
        model_name=model,
        tokenizer_name=model,
        mm_processor_kwargs=None,
    )

    actual_max_tokens = get_max_qwen2_vl_image_tokens(
        InputContext(ctx.model_config), **mm_processor_kwargs)
    assert actual_max_tokens == expected_max_tokens


@pytest.mark.parametrize(
    "mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
        ({}, 1426, (5704, 1176)),
@@ -82,6 +49,7 @@ def test_processor_override(
        model_name=model,
        tokenizer_name=model,
        mm_processor_kwargs=None,
        limit_mm_per_prompt={"image": num_imgs},
    )
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    ctx = InputProcessingContext(ctx.model_config, tokenizer)
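A worked check of the expected values in the parametrization above, assuming the usual Qwen2-VL processor defaults (patch size 14 and spatial merge 2, i.e. one LLM token per 28x28 pixel block, and a default max_pixels of 12845056). These constants come from the Qwen2-VL processor configuration, not from this diff.

import math

# One LLM token covers a 28x28 pixel block (patch_size=14, merge_size=2).
PIXELS_PER_TOKEN = 28 * 28

# Default processor limit: 12_845_056 pixels -> 16384 tokens, matching the
# ({}, 16384) case above.
print(12_845_056 // PIXELS_PER_TOKEN)  # 16384

# With max_pixels=512**2, the largest 28-aligned square that fits is
# 504x504, i.e. an 18x18 token grid -> 324 tokens, matching the 324 case.
side = math.floor(512 / 28) * 28
print((side // 28) ** 2)  # 324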
@@ -538,6 +538,11 @@ def _test_processing_cache_correctness(
    else:
        hf_overrides = {}

    limit_mm_per_prompt = {
        modality: 3 if supports_multi else 1
        for modality, supports_multi in modalities.items()
    }

    model_config = ModelConfig(
        model_id,
        task="auto",
@@ -548,6 +553,7 @@ def _test_processing_cache_correctness(
        dtype="float16",
        revision=None,
        hf_overrides=hf_overrides,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

@@ -580,18 +586,14 @@ def _test_processing_cache_correctness(
                min_wh=128,
                max_wh=256),
        "audio":
        partial(_rand_audio, rng, min_len=256, max_len=512, sr=16000),
    }
    input_max_count = {
        modality: 3 if supports_multi else 1
        for modality, supports_multi in modalities.items()
        partial(_rand_audio, rng, min_len=512, max_len=1024, sr=16000),
    }

    for batch_idx in range(num_batches):
        mm_data = {
            k:
            [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
             for _ in range(rng.randint(input_max_count[k]))]
             for _ in range(rng.randint(limit_mm_per_prompt[k]))]
            for k in modalities
        }
@@ -331,13 +331,7 @@ class InputRegistry:
                trust_remote_code=model_config.trust_remote_code,
            )
            processor = mm_registry.create_processor(model_config, tokenizer)

            mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
            mm_max_tokens = mm_registry.get_max_tokens_by_modality(
                model_config)

            dummy_data = processor.get_dummy_data(seq_len, mm_counts,
                                                  mm_max_tokens)
            dummy_data = processor.get_dummy_data(seq_len)
        else:
            model_cls, _ = get_model_architecture(model_config)
            if is_encoder_data:
@@ -1,5 +1,5 @@
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
                    Union)
from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
                    TypedDict, Union)

import torch
import torch.nn as nn
@@ -9,7 +9,6 @@ from transformers import BatchFeature, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -87,8 +86,8 @@ class AriaVisionModel(nn.Module):
    def forward(
        self,
        pixel_values: torch.Tensor,
        pixel_mask: Optional[torch.BoolTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]:
        pixel_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

        vit_oup = self.vision_model(
@@ -100,7 +99,8 @@ class AriaVisionModel(nn.Module):

        return vit_oup, image_atts

    def _create_patch_attention_mask(self, pixel_mask):
    def _create_patch_attention_mask(
            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
        if pixel_mask is None:
            return None

@@ -115,7 +115,8 @@ class AriaVisionModel(nn.Module):
        )
        return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

    def _create_image_attention_mask(self, patch_attention_mask):
    def _create_image_attention_mask(
            self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
        if patch_attention_mask is None:
            return None

@@ -125,13 +126,13 @@ class AriaVisionModel(nn.Module):

class FFN(nn.Module):

    def __init__(self, embed_dim, ff_dim, output_dim):
    def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
        super().__init__()
        self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
        self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
        self.act = get_act_fn("gelu_new")

    def forward(self, hidden_states):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.linear_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_out(hidden_states)
@@ -140,7 +141,7 @@ class FFN(nn.Module):

class CrossAttention(nn.Module):

    def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
    def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
@@ -149,12 +150,16 @@ class CrossAttention(nn.Module):

        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(drop_out_rate)

        self.layer_norm = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(kv_dim)

    def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
    def forward(
        self,
        x: torch.Tensor,
        hidden_states: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        normed_hidden_states = self.layer_norm(hidden_states)
        query = self.q_proj(normed_hidden_states).permute(1, 0, 2)

@@ -169,11 +174,7 @@ class CrossAttention(nn.Module):

        attn_output = attn_output.permute(1, 0, 2)

        if add_residual:
            attn_output = hidden_states + self.dropout(
                self.linear(attn_output))
        else:
            attn_output = self.dropout(self.linear(attn_output))
        attn_output = self.linear(attn_output)

        return attn_output

@@ -201,14 +202,14 @@ class AriaProjector(nn.Module):

    def __init__(
        self,
        patch_to_query_dict,
        embed_dim,
        num_heads,
        kv_dim,
        ff_dim,
        output_dim,
        norm_layer=nn.LayerNorm,
    ):
        patch_to_query_dict: dict[int, int],
        embed_dim: int,
        num_heads: int,
        kv_dim: int,
        ff_dim: int,
        output_dim: int,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.patch_to_query_dict = patch_to_query_dict
        self.embed_dim = embed_dim
@@ -224,7 +225,11 @@ class AriaProjector(nn.Module):
        self.ln_ffn = norm_layer(embed_dim)
        self.ffn = FFN(embed_dim, ff_dim, output_dim)

    def forward(self, x, attn_mask=None):
    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        bs = x.shape[0]
        queries = self.query.unsqueeze(0).repeat(bs, 1, 1)

@@ -442,13 +447,18 @@ def build_mm_projector(config: PretrainedConfig):
    )


def get_max_aria_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config()
    return max(hf_config.projector_patch_to_query_dict.values())


class AriaMultiModalProcessor(BaseMultiModalProcessor):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def _get_num_image_tokens(self) -> int:
        hf_config = self.ctx.get_hf_config()
        return max(hf_config.projector_patch_to_query_dict.values())

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
@@ -468,13 +478,13 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
        hf_config = self.ctx.get_hf_config()
        image_token_id = hf_config.image_token_index

        max_image_tokens = get_max_aria_image_tokens(self.ctx)
        num_image_tokens = self._get_num_image_tokens()

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=[image_token_id] * max_image_tokens,
                replacement=[image_token_id] * num_image_tokens,
            )
        ]

@@ -504,7 +514,6 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
    )


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_aria_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """
@@ -9,7 +9,6 @@ from transformers import (BatchFeature, Blip2Config, Blip2Processor,

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -18,7 +17,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputsV2, MultiModalKwargs,
                                    NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
@@ -398,15 +396,17 @@ class Blip2QFormerModel(nn.Module):
        return sequence_output


def get_max_blip2_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(Blip2Config)
    return hf_config.num_query_tokens


class Blip2MultiModalProcessor(BaseMultiModalProcessor):

    def _get_data_parser(self) -> MultiModalDataParser:
        return MultiModalDataParser(max_mm_counts={"image": 1})
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def _get_num_image_tokens(self) -> int:
        hf_config = self.ctx.get_hf_config(Blip2Config)
        return hf_config.num_query_tokens

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}

    def _get_hf_processor(self) -> Blip2Processor:
        return self.ctx.get_hf_processor(Blip2Processor)
@@ -427,7 +427,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        max_image_tokens = get_max_blip2_image_tokens(self.ctx)
        max_image_tokens = self._get_num_image_tokens()

        return [
            PromptReplacement(
@@ -480,7 +480,6 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
    )


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
@@ -11,7 +11,6 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -31,7 +30,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputsV2, MultiModalKwargs,
                                    NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
@@ -43,11 +41,6 @@ from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix, merge_multimodal_embeddings)

# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512
CHAMELEON_IMAGE_SEQ_LENGTH = 1024


class ChameleonImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
@@ -55,14 +48,17 @@ class ChameleonImagePixelInputs(TypedDict):
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


def get_max_chameleon_image_tokens(ctx: InputContext):
    return CHAMELEON_IMAGE_SEQ_LENGTH


class ChameleonMultiModalProcessor(BaseMultiModalProcessor):

    def _get_data_parser(self) -> MultiModalDataParser:
        return MultiModalDataParser(max_mm_counts={"image": 1})
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def _get_num_image_tokens(self) -> int:
        processor = self._get_hf_processor()
        return processor.image_seq_length

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}

    def _get_hf_processor(self) -> ChameleonProcessor:
        return self.ctx.get_hf_processor(ChameleonProcessor)
@@ -88,7 +84,7 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
                target="<image>",
                replacement="".join([
                    processor.image_start_token,
                    processor.image_token * CHAMELEON_IMAGE_SEQ_LENGTH,
                    processor.image_token * self._get_num_image_tokens(),
                    processor.image_end_token,
                ]),
            )
@@ -98,12 +94,15 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        config = self.ctx.get_hf_config(ChameleonConfig)

        width = height = config.vq_config.resolution
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=CHAMELEON_CROP_SIZE_WIDTH,
                                   height=CHAMELEON_CROP_SIZE_HEIGHT,
            self._get_dummy_images(width=width,
                                   height=height,
                                   num_images=num_images)
        }

@@ -902,7 +901,6 @@ class ChameleonModel(nn.Module):
        return hidden_states


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
@@ -931,9 +929,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
            self.model.make_empty_intermediate_tensors)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:

        expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT,
                         CHAMELEON_CROP_SIZE_WIDTH)
        vq_config: ChameleonVQVAEConfig = self.config.vq_config
        expected_dims = (3, vq_config.resolution, vq_config.resolution)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
@@ -25,7 +25,6 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
@@ -34,7 +33,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalInputsV2, MultiModalKwargs,
                                    NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
@@ -48,9 +47,6 @@ from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019

MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920


class FuyuImagePatchInputs(TypedDict):
    type: Literal["image_patches"]
@@ -67,43 +63,49 @@ class FuyuImagePatchInputs(TypedDict):
    """


def _get_fuyu_num_image_tokens(
    image_height: int,
    image_width: int,
) -> Tuple[int, int]:
    """
    Calculate the number of image tokens needed for a given image size.

    The expected Fuyu image prompts can be expressed as:

    .. code-block::
        (image_token * ncols + newline_token) * nrows

    Args:
        image_size: Tuple[int, int] - `(width, height)` of the image

    Returns:
        ncols: int - number of image tokens in `x` direction
        nrows: int - number of image tokens in `y` direction
    """
    ncols = math.ceil(image_width / 30)
    nrows = math.ceil(image_height / 30)
    return ncols, nrows


def get_max_fuyu_image_tokens(ctx: InputContext):
    ncols, nrows = _get_fuyu_num_image_tokens(
        image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )

    return (ncols + 1) * nrows


class FuyuMultiModalProcessor(BaseMultiModalProcessor):

    def _get_data_parser(self) -> MultiModalDataParser:
        return MultiModalDataParser(max_mm_counts={"image": 1})
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def _get_image_target_size(self) -> ImageSize:
        processor = self._get_hf_processor()
        image_processor: FuyuImageProcessor = processor.image_processor

        target_size = image_processor.size
        return ImageSize(width=target_size["width"],
                         height=target_size["height"])

    def _get_image_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        target_width, target_height = self._get_image_target_size()

        if not (image_width <= target_width and image_height <= target_height):
            height_scale_factor = target_height / image_height
            width_scale_factor = target_width / image_width
            optimal_scale_factor = min(height_scale_factor, width_scale_factor)

            image_height = int(image_height * optimal_scale_factor)
            image_width = int(image_width * optimal_scale_factor)

        ncols = math.ceil(image_width / 30)
        nrows = math.ceil(image_height / 30)
        return ncols, nrows

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        target_width, target_height = self._get_image_target_size()

        max_ncols, max_nrows = self._get_image_grid_size(
            image_width=target_width,
            image_height=target_height,
        )
        max_image_tokens = (max_ncols + 1) * max_nrows

        return {"image": max_image_tokens}

    def _get_hf_processor(self) -> FuyuProcessor:
        return self.ctx.get_hf_processor(FuyuProcessor)
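As a quick sanity check of the numbers above (illustrative arithmetic, not part of the diff): for the 1920x1080 target size that get_mm_max_tokens_per_item feeds into _get_image_grid_size, the grid is 64 columns by 36 rows, and with one newline token per row the per-image maximum is (64 + 1) * 36 = 2340 tokens.

import math

# Fuyu packs an image as (image_token * ncols + newline_token) * nrows,
# with one token per 30x30 patch (see _get_image_grid_size above).
target_width, target_height = 1920, 1080
ncols = math.ceil(target_width / 30)   # 64
nrows = math.ceil(target_height / 30)  # 36
max_image_tokens = (ncols + 1) * nrows  # 65 * 36 = 2340
print(ncols, nrows, max_image_tokens)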
@@ -166,28 +168,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        hf_processor = self._get_hf_processor()
        image_processor: FuyuImageProcessor = hf_processor.image_processor
        target_size = image_processor.size
        target_height, target_width = (target_size["height"],
                                       target_size["width"])

        def get_replacement_fuyu(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)
            width, height = image_size.width, image_size.height
            if not (width <= target_width and height <= target_height):
                height_scale_factor = target_height / height
                width_scale_factor = target_width / width
                optimal_scale_factor = min(height_scale_factor,
                                           width_scale_factor)

                height = int(height * optimal_scale_factor)
                width = int(width * optimal_scale_factor)

            ncols, nrows = _get_fuyu_num_image_tokens(
                image_width=width,
                image_height=height,
            ncols, nrows = self._get_image_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
@@ -225,12 +212,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        target_width, target_height = self._get_image_target_size()
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
                                   height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

@@ -240,7 +228,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
    )


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
@@ -119,6 +119,12 @@ def get_max_llava_image_tokens(ctx: InputContext):

class LlavaMultiModalProcessor(BaseMultiModalProcessor):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        return {"image": get_max_llava_image_tokens(self.ctx)}

    def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
        return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))

@@ -324,7 +330,6 @@ def init_vision_tower_for_llava(
    raise NotImplementedError(msg)


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    # BitandBytes specific attributes
@@ -649,7 +654,6 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):

# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    pass
@@ -23,7 +23,6 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -306,25 +305,32 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
        return image_features_hd_newline


def get_max_phi3v_image_tokens(
    ctx: InputContext,
    *,
    num_crops: Optional[int] = None,
) -> int:
    hf_processor_mm_kwargs = {}
    if num_crops:
        hf_processor_mm_kwargs["num_crops"] = num_crops

    processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)

    return processor.calc_num_image_tokens_from_image_size(
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )


class Phi3VMultiModalProcessor(BaseMultiModalProcessor):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def _get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        processor = self._get_hf_processor()

        return processor.calc_num_image_tokens_from_image_size(  # type: ignore
            width=image_width,
            height=image_height,
        )

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        max_image_tokens = self._get_num_image_tokens(
            image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        return {"image": max_image_tokens}

    def _get_hf_processor(
        self,
        *,
@@ -332,6 +338,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
    ) -> ProcessorMixin:
        if num_crops is not None:
            return self.ctx.get_hf_processor(num_crops=num_crops)

        return self.ctx.get_hf_processor()

    def _call_hf_processor(
@@ -375,7 +382,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
    ) -> list[PromptReplacement]:
        hf_processor = self._get_hf_processor()
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
        image_processor = hf_processor.image_processor  # type: ignore

        tokenizer = self._get_tokenizer()
        bos_token_id = tokenizer.bos_token_id
@@ -385,9 +391,9 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            num_tokens = image_processor.calc_num_image_tokens_from_image_size(
                width=image_size.width,
                height=image_size.height,
            num_tokens = self._get_num_image_tokens(
                image_width=image_size.width,
                image_height=image_size.height,
            )

            return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
@@ -467,7 +473,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
        return result


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    hf_to_vllm_mapper = WeightsMapper(
@@ -33,13 +33,12 @@ from transformers.models.whisper import WhisperFeatureExtractor

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
@@ -80,15 +79,18 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    return feat_lengths, output_lengths


def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
    hf_config = ctx.get_hf_config(Qwen2AudioConfig)
    max_source_position = hf_config.audio_config.max_source_positions
    output_lengths = (max_source_position - 2) // 2 + 1
    return output_lengths


class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": None}

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
        max_source_positions = hf_config.audio_config.max_source_positions
        max_output_lengths = (max_source_positions - 2) // 2 + 1

        return {"audio": max_output_lengths}

    def _get_hf_processor(
        self,
        *,
@@ -157,11 +159,21 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
            audio_output_lengths = []
        else:
            assert isinstance(feature_attention_mask, torch.Tensor)
            _, audio_output_lengths = _get_feat_extract_output_lengths(
            _, audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1))

            audio_output_lengths = audio_output_lens.tolist()

        def get_replacement_qwen2_audio(item_idx: int):
            return [placeholder] * audio_output_lengths[item_idx]
            num_placeholders = audio_output_lengths[item_idx]
            if num_placeholders == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio = audios.get(item_idx)
                raise ValueError(
                    f"The audio {audio} (len={len(audio)}) is too short "
                    "to be represented inside the model")

            return [placeholder] * num_placeholders

        return [
            PromptReplacement(
@@ -171,6 +183,14 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
            )
        ]

    def _always_apply_prompt_replacements(self) -> bool:
        # HF never applies prompt replacements, so we have to do it ourselves
        # _find_placeholders may incorrectly think that HF has already performed
        # processing for multi-audio input when the input audios are short
        # (the corresponding placeholders may take up fewer tokens than
        # the number of audio items)
        return True

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
@@ -192,8 +212,6 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
        )


@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_max_qwen2_audio_audio_tokens)
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
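For a concrete number behind get_mm_max_tokens_per_item above (hedged: 1500 is the usual max_source_positions of a Whisper-style audio encoder; the real value is read from the model's audio config at runtime):

# Qwen2-Audio derives the per-audio token budget from the encoder length:
# output_lengths = (max_source_positions - 2) // 2 + 1
max_source_positions = 1500  # assumed Whisper-style default, see hf_config
print((max_source_positions - 2) // 2 + 1)  # 750 audio placeholder tokens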
@@ -40,7 +40,6 @@ from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
@@ -650,8 +649,9 @@ def _get_vision_info(
    width: int,
    min_pixels: int,
    max_pixels: int,
    *,
    do_resize: bool = True,
    data_type_key: str = "image",
    modality: str = "image",
    mm_count: int = 1,
):
    """Get information (resized height / width and number of vision tokens)
@@ -671,11 +671,12 @@ def _get_vision_info(
    else:
        resized_height, resized_width = height, width

    if data_type_key == "image":
    if modality == "image":
        grid_t = mm_count
    else:
        assert data_type_key == "video"
    elif modality == "video":
        grid_t = max(mm_count // temporal_patch_size, 1)
    else:
        raise ValueError(f"Modality {modality} is not supported")

    grid_h = resized_height // patch_size
    grid_w = resized_width // patch_size
@@ -691,41 +692,11 @@ def _get_image_processor(hf_processor: Qwen2VLProcessor):
    return image_processor


def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
                               data_type_key: str,
                               *,
                               min_pixels: Optional[int] = None,
                               max_pixels: Optional[int] = None) -> int:
    hf_config = ctx.get_hf_config(Qwen2VLConfig)
    vision_config = hf_config.vision_config

    hf_processor = ctx.get_hf_processor(Qwen2VLProcessor)
    image_processor = _get_image_processor(hf_processor)

    _, _, max_llm_image_tokens = _get_vision_info(
        vision_config,
        height=9999999,
        width=9999999,
        min_pixels=min_pixels or image_processor.min_pixels,
        max_pixels=max_pixels or image_processor.max_pixels,
        data_type_key=data_type_key,
    )
    return max_llm_image_tokens


get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens,
                                        data_type_key="image")
get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
                                        data_type_key="video")


class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
                                            dict[str, torch.Tensor]]):

    def __init__(self, data: dict, modality: str) -> None:
        super().__init__(data)

        self.modality = modality
        super().__init__(data, modality)

        grid_thw = data[f"{modality}_grid_thw"]
        slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
@@ -734,9 +705,6 @@ class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
            for i in range(len(grid_thw))
        ]

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r})")

    def get_count(self) -> int:
        return len(self.data[f"{self.modality}_grid_thw"])

@@ -792,6 +760,32 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):

class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None, "video": None}

    def _get_max_mm_tokens(self, modality: str) -> int:
        hf_config = self.ctx.get_hf_config(Qwen2VLConfig)
        vision_config = hf_config.vision_config

        hf_processor = self._get_hf_processor()
        image_processor = _get_image_processor(hf_processor)

        _, _, max_llm_image_tokens = _get_vision_info(
            vision_config,
            height=9999999,
            width=9999999,
            min_pixels=image_processor.min_pixels,
            max_pixels=image_processor.max_pixels,
            modality=modality,
        )
        return max_llm_image_tokens

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        return {
            "image": self._get_max_mm_tokens("image"),
            "video": self._get_max_mm_tokens("video"),
        }

    def _get_data_parser(self) -> MultiModalDataParser:
        return Qwen2MultiModalDataParser()

@@ -908,9 +902,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
        )


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "video", get_max_qwen2_vl_video_tokens)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsLoRA, SupportsPP):
@@ -2,7 +2,7 @@
"""PyTorch Ultravox model."""

import math
from functools import cached_property, lru_cache
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                    TypedDict, Union)

@@ -17,7 +17,6 @@ from transformers.models.whisper.modeling_whisper import WhisperEncoder

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -58,23 +57,18 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
                            UltravoxAudioEmbeddingInputs]


@lru_cache
def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor:
    return WhisperFeatureExtractor.from_pretrained(model_id)


def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor:
    hf_config = ctx.get_hf_config(UltravoxConfig)
    return cached_feature_extractor(hf_config.audio_model_id)


def get_ultravox_max_audio_tokens(ctx: InputContext):
    feature_extractor = whisper_feature_extractor(ctx)
    return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND)


class UltravoxMultiModalProcessor(BaseMultiModalProcessor):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": None}

    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        feature_extractor = self._get_feature_extractor()
        max_audio_tokens = math.ceil(feature_extractor.chunk_length *
                                     _AUDIO_TOKENS_PER_SECOND)

        return {"audio": max_audio_tokens}

    def _get_hf_processor(
        self,
        *,
@@ -322,8 +316,6 @@ class ModifiedWhisperEncoder(WhisperEncoder):
        return hidden_states


@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "audio", get_ultravox_max_audio_tokens)
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
@@ -21,10 +21,15 @@ _I = TypeVar("_I")

class ModalityDataItems(ABC, Generic[_T, _I]):

    def __init__(self, data: _T) -> None:
    def __init__(self, data: _T, modality: str) -> None:
        super().__init__()

        self.data = data
        self.modality = modality

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r}, "
                f"len={len(self)})")

    def __len__(self) -> int:
        return self.get_count()
@@ -64,14 +69,6 @@ class ModalityDataItems(ABC, Generic[_T, _I]):

class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):

    def __init__(self, data: Sequence[_T], modality: str) -> None:
        super().__init__(data)

        self.modality = modality

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r})")

    def get_count(self) -> int:
        return len(self.data)

@@ -87,14 +84,6 @@ class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):

class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):

    def __init__(self, data: NestedTensors, modality: str) -> None:
        super().__init__(data)

        self.modality = modality

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r})")

    def get_count(self) -> int:
        return len(self.data)

@@ -222,22 +211,13 @@ class MultiModalDataParser:
    Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.

    Args:
        max_mm_counts (Mapping[str, int]): The maximum allowed number of items
            belonging to each modality. This effectively sets a hard limit over
            `--limit-mm-per-prompt`.
        target_sr (float, optional): Enables automatic resampling of audio
            items to the model's expected sampling rate.
    """

    def __init__(
        self,
        *,
        max_mm_counts: Mapping[str, int] = {},
        target_sr: Optional[float] = None,
    ) -> None:
    def __init__(self, *, target_sr: Optional[float] = None) -> None:
        super().__init__()

        self.max_mm_counts = max_mm_counts
        self.target_sr = target_sr

    def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
@@ -345,7 +325,6 @@ class MultiModalDataParser:

    def parse_mm_data(self,
                      mm_data: MultiModalDataDict) -> MultiModalDataItems:
        max_mm_counts = self.max_mm_counts
        subparsers = self._get_subparsers()

        mm_items = MultiModalDataItems()
@@ -353,16 +332,6 @@ class MultiModalDataParser:
            if k not in subparsers:
                raise ValueError(f"Unsupported modality: {k}")

            modality_items = subparsers[k](v)

            if k in max_mm_counts:
                max_count = max_mm_counts[k]
                if len(modality_items) > max_count:
                    raise ValueError(
                        f"This model supports at most {max_count} {k} items "
                        f"per prompt, but {len(modality_items)} {k} items "
                        "were given or set as its limit_mm_per_prompt.")

            mm_items[k] = modality_items
            mm_items[k] = subparsers[k](v)

        return mm_items
@@ -624,6 +624,29 @@ class BaseMultiModalProcessor(ABC):
    ) -> MultiModalInputsV2:
        return self.apply(prompt, mm_data, hf_processor_mm_kwargs)

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means unlimited number of items.

        Omitting a modality from the returned dictionary means that
        it is not supported at all.
        """
        raise NotImplementedError

    @abstractmethod
    def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
        """
        Get the maximum possible number of tokens per data item
        for each modality.

        The dictionary returned by this method should have the same
        keys as that returned by :meth:`get_supported_mm_limits`.
        """
        raise NotImplementedError

    def _get_data_parser(self) -> MultiModalDataParser:
        """
        Construct a data parser to preprocess multi-modal data items
@@ -653,7 +676,18 @@ class BaseMultiModalProcessor(ABC):
        before passing them to :meth:`_get_hf_mm_data`.
        """
        parser = self._get_data_parser()
        return parser.parse_mm_data(mm_data)
        mm_items = parser.parse_mm_data(mm_data)

        mm_limits = self.ctx.get_mm_config().limit_per_prompt
        for modality, items in mm_items.items():
            limit = mm_limits.get(modality, 1)
            if len(items) > limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but passed {len(items)} "
                    f"{modality} items in the same prompt.")

        return mm_items

    @abstractmethod
    def _get_mm_fields_config(
@@ -901,6 +935,17 @@ class BaseMultiModalProcessor(ABC):

        return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]

    def _always_apply_prompt_replacements(self) -> bool:
        """
        A flag which can be overridden so that
        :meth:`_apply_prompt_replacements` is always called even if we
        detect that HF has performed processing via :meth:`_find_placeholders`.

        This is useful in cases where :meth:`_find_placeholders` cannot be
        reliably used to detect whether HF has performed processing or not.
        """
        return False

    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
@@ -995,7 +1040,7 @@ class BaseMultiModalProcessor(ABC):
        all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
                                                   mm_item_counts)

        if all_placeholders:
        if all_placeholders and not self._always_apply_prompt_replacements():
            tokenizer = self._get_tokenizer()
            prompt_text = _decode(tokenizer, prompt_ids)
        else:
@@ -1009,10 +1054,27 @@ class BaseMultiModalProcessor(ABC):
                mm_item_counts,
            )

        mm_placeholders = {
            modality: [item.to_range() for item in items]
            for modality, items in full_groupby_modality(all_placeholders)
        }
        mm_placeholders = dict[str, list[PlaceholderRange]]()
        err_suffix = ("This suggests a problem with your implementation of "
                      "the merged multi-modal processor for this model, "
                      "particularly in the `_get_prompt_replacements` method.")

        for modality, placeholders in full_groupby_modality(all_placeholders):
            if modality not in mm_items:
                raise AssertionError(
                    f"Expected no placeholders for {modality=}, "
                    f"but found {placeholders=}. Input items: {mm_items}"
                    f"\n{err_suffix}")

            if len(placeholders) != len(mm_items[modality]):
                raise AssertionError(
                    f"Expected length of {placeholders=} for {modality=} "
                    f"to equal that of input items: {mm_items[modality]}"
                    f"\n{err_suffix}")

            mm_placeholders[modality] = [
                item.to_range() for item in placeholders
            ]

        return MultiModalInputsV2(
            type="multimodal",
@@ -1063,15 +1125,38 @@ class BaseMultiModalProcessor(ABC):
        """
        raise NotImplementedError

    def get_dummy_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_max_tokens: Mapping[str, int],
    ) -> DummyData:
    def _get_and_validate_dummy_mm_counts(self) -> Mapping[str, int]:
        mm_limit_per_prompt = self.ctx.get_mm_config().limit_per_prompt
        supported_mm_limits = self.get_supported_mm_limits()

        mm_limits = {
            modality: mm_limit_per_prompt.get(modality, 1)
            for modality in supported_mm_limits
        }

        for modality, supported_limit in supported_mm_limits.items():
            limit = mm_limits[modality]
            if supported_limit is not None and supported_limit < limit:
                raise ValueError(
                    f"You set {modality}={limit} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but this model only supports "
                    f"at most {supported_limit} {modality} items.")

        return mm_limits

    def get_dummy_data(self, seq_len: int) -> DummyData:
        # Avoid circular import
        from vllm.sequence import SequenceData

        mm_counts = self._get_and_validate_dummy_mm_counts()
        mm_max_tokens_per_item = self.get_mm_max_tokens_per_item()
        if mm_counts.keys() != mm_max_tokens_per_item.keys():
            raise AssertionError(
                "The keys returned by `get_supported_mm_limits`"
                f"({set(mm_counts.keys())}) should be the same as those "
                "returned by `get_mm_max_tokens_per_item` "
                f"({set(mm_max_tokens_per_item.keys())})")

        processor_inputs = self._get_dummy_mm_inputs(mm_counts)
        mm_inputs = self.apply(
            prompt_text=processor_inputs.prompt_text,
@@ -1087,7 +1172,7 @@ class BaseMultiModalProcessor(ABC):
            for modality, placeholders in placeholders_by_modality.items()
        }
        expected_placeholders_by_modality = {
            modality: mm_max_tokens[modality]
            modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
            for modality in placeholders_by_modality
        }
        if total_placeholders_by_modality != expected_placeholders_by_modality:
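To make the profiling arithmetic above concrete (illustrative values, not taken from this diff): get_dummy_data now multiplies each modality's per-item token maximum by the validated item count, so the expected placeholder totals are simply:

# Example: a model reporting these per-item maxima, profiled with a
# limit_mm_per_prompt of 3 images and 1 audio (illustrative numbers).
mm_max_tokens_per_item = {"image": 781, "audio": 750}
mm_counts = {"image": 3, "audio": 1}

expected_placeholders_by_modality = {
    modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
    for modality in mm_counts
}
print(expected_placeholders_by_modality)  # {'image': 2343, 'audio': 750}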
@@ -15,6 +15,7 @@ from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import BaseMultiModalProcessor, ProcessingCache
from .utils import cached_get_tokenizer
from .video import VideoPlugin

if TYPE_CHECKING:
@@ -219,6 +220,10 @@ class MultiModalRegistry:
        Note:
            This is currently directly used only in V1.
        """
        if self.has_processor(model_config):
            tokenizer = cached_get_tokenizer(model_config.tokenizer)
            processor = self.create_processor(model_config, tokenizer)
            return processor.get_mm_max_tokens_per_item()

        return {
            key: plugin.get_max_multimodal_tokens(model_config)