[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-01-08 18:59:58 +08:00 committed by GitHub
parent f12141170a
commit 2a0596bc48
23 changed files with 833 additions and 760 deletions
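In rough terms (pieced together from the diffs below), this commit splits the old monolithic multimodal processor into three cooperating classes plus a standalone profiler: BaseProcessingInfo carries the model-specific knowledge (HF config/processor access, supported modality limits, per-item token counts), BaseDummyInputsBuilder produces the dummy inputs used for memory profiling, BaseMultiModalProcessor does the actual prompt/data processing and reaches the other two through self.info and self.dummy_inputs, and MultiModalProfiler drives profiling from the outside. A minimal sketch of the new registration pattern, with MyModel* names standing in for a real model:

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder


class MyModelProcessingInfo(BaseProcessingInfo):
    # Static, model-specific knowledge: HF config/processor access,
    # supported modality limits, max tokens per multimodal item.
    ...


class MyModelDummyInputsBuilder(BaseDummyInputsBuilder[MyModelProcessingInfo]):
    # Builds the dummy prompt + multimodal data used for memory profiling,
    # reading sizes and token counts from self.info.
    ...


class MyModelMultiModalProcessor(BaseMultiModalProcessor[MyModelProcessingInfo]):
    # Applies the HF processor to real inputs; model-specific knowledge is
    # reached through self.info instead of private _get_* helpers.
    ...


@MULTIMODAL_REGISTRY.register_processor(MyModelMultiModalProcessor,
                                        info=MyModelProcessingInfo,
                                        dummy_inputs=MyModelDummyInputsBuilder)
class MyModelForConditionalGeneration:  # placeholder model class for the sketch
    ...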


@@ -4,24 +4,17 @@ from functools import partial
 import pytest
 from PIL import Image
 from pqdm.threads import pqdm
-from transformers import AutoTokenizer
 
-from vllm.inputs import InputProcessingContext
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.parse import ImageSize
+from vllm.multimodal.processing import BaseMultiModalProcessor
+from vllm.multimodal.utils import cached_get_tokenizer
 
 from ....utils import build_model_context
 
-# Fixtures lazy import to avoid initializing CUDA during test collection
-@pytest.fixture()
-def processor_for_llava_next():
-    from vllm.model_executor.models.llava_next import (
-        LlavaNextMultiModalProcessor)
-    return LlavaNextMultiModalProcessor
-
 
 def _validate_image_prompt_replacements_one(
-    processor,
+    processor: BaseMultiModalProcessor,
     num_imgs: int,
     failed_size_excs: list[tuple[ImageSize, Exception]],
     image_size: ImageSize,
@@ -78,20 +71,17 @@ def _test_image_prompt_replacements(
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
 @pytest.mark.parametrize("num_imgs", [1, 2])
-def test_processor_prompt_replacements_regression(
-    processor_for_llava_next,
-    model_id: str,
-    num_imgs: int,
-):
+def test_processor_prompt_replacements_regression(model_id, num_imgs):
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
-    processor = processor_for_llava_next(ctx)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+    )
 
     image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
                     (488, 183), (2560, 1669)]
@@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression(
                   "Comment this out to run it manually.")
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
 @pytest.mark.parametrize("num_imgs", [1])
-def test_processor_prompt_replacements_all(
-    processor_for_llava_next,
-    model_id: str,
-    num_imgs: int,
-):
+def test_processor_prompt_replacements_all(model_id, num_imgs):
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
-    processor = processor_for_llava_next(ctx)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+    )
 
     seen_aspect_ratios = set[float]()
     image_sizes = list[ImageSize]()
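The same pattern repeats in the other processing tests below: instead of a lazily imported fixture plus a hand-built InputProcessingContext, each test now asks the registry for the processor. A condensed sketch of that call pattern (build_model_context and the model name are taken from the test above):

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer

ctx = build_model_context(
    model_name="llava-hf/llava-v1.6-mistral-7b-hf",
    tokenizer_name="llava-hf/llava-v1.6-mistral-7b-hf",
    mm_processor_kwargs=None,
    limit_mm_per_prompt={"image": 1},
)
# The registry resolves the processor registered for the model architecture,
# so the test no longer imports LlavaNextMultiModalProcessor (or touches CUDA)
# at collection time.
processor = MULTIMODAL_REGISTRY.create_processor(
    ctx.model_config,
    tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
)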


@@ -4,24 +4,17 @@ from functools import partial
 import pytest
 from PIL import Image
 from pqdm.threads import pqdm
-from transformers import AutoTokenizer
 
-from vllm.inputs import InputProcessingContext
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.parse import ImageSize
+from vllm.multimodal.processing import BaseMultiModalProcessor
+from vllm.multimodal.utils import cached_get_tokenizer
 
 from ....utils import build_model_context
 
-# Fixtures lazy import to avoid initializing CUDA during test collection
-@pytest.fixture()
-def processor_for_llava_onevision():
-    from vllm.model_executor.models.llava_onevision import (
-        LlavaOnevisionMultiModalProcessor)
-    return LlavaOnevisionMultiModalProcessor
-
 
 def _validate_image_prompt_replacements_one(
-    processor,
+    processor: BaseMultiModalProcessor,
    num_imgs: int,
     failed_size_excs: list[tuple[ImageSize, Exception]],
     image_size: ImageSize,
@@ -77,20 +70,17 @@ def _test_image_prompt_replacements(
 
 @pytest.mark.parametrize("model_id",
                          ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
 @pytest.mark.parametrize("num_imgs", [1, 2])
-def test_processor_prompt_replacements_regression(
-    processor_for_llava_onevision,
-    model_id: str,
-    num_imgs: int,
-):
+def test_processor_prompt_replacements_regression(model_id, num_imgs):
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
-    processor = processor_for_llava_onevision(ctx)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+    )
 
     image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
                     (488, 183), (2560, 1669)]
@@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression(
                   "Comment this out to run it manually.")
 @pytest.mark.parametrize("model_id",
                          ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
 @pytest.mark.parametrize("num_imgs", [1])
-def test_processor_prompt_replacements_all(
-    processor_for_llava_onevision,
-    model_id: str,
-    num_imgs: int,
-):
+def test_processor_prompt_replacements_all(model_id, num_imgs):
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
-    processor = processor_for_llava_onevision(ctx)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
+    )
 
     seen_aspect_ratios = set[float]()
     image_sizes = list[ImageSize]()


@@ -1,21 +1,13 @@
 """Tests for phi3v's multimodal preprocessing kwargs."""
 import pytest
-from transformers import AutoTokenizer
 
-from vllm.inputs import InputProcessingContext
-from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.utils import cached_get_tokenizer
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
 
-# Wrap lazy imports to avoid initializing CUDA during test collection
-@pytest.fixture()
-def processor_for_phi3v():
-    from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
-    return Phi3VMultiModalProcessor
-
 
 @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
 # yapf: disable
 @pytest.mark.parametrize(
@@ -29,7 +21,6 @@ def processor_for_phi3v():
 # yapf: enable
 @pytest.mark.parametrize("num_imgs", [1, 2])
 def test_processor_override(
-    processor_for_phi3v,
     image_assets: _ImageAssets,
     model_id: str,
     mm_processor_kwargs: dict[str, int],
@@ -37,21 +28,26 @@ def test_processor_override(
     num_imgs: int,
 ):
     """Ensure input_processor_for_phi3v handles num_crops properly."""
+    # Avoid initializing CUDA early
+    from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
+
     ctx = build_model_context(
         model_name=model_id,
         tokenizer_name=model_id,
         trust_remote_code=True,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
+    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=tokenizer,
+    )
 
     # Build the image str / prompt based on the number of images we pass
     img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
     prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
     mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
 
-    processor = processor_for_phi3v(ctx)
     processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
 
     # Ensure we have the right number of placeholders per num_crops size


@@ -1,19 +1,12 @@
 import pytest
-from transformers import AutoTokenizer
 
-from vllm.inputs import InputProcessingContext
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.utils import cached_get_tokenizer
 
 from .....conftest import _ImageAssets
 from ....utils import build_model_context
 
-# Fixtures lazy import to avoid initializing CUDA during test collection
-@pytest.fixture()
-def processor_for_qwen2_vl():
-    from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
-    return Qwen2VLMultiModalProcessor
-
 
 @pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
 # yapf: disable
 @pytest.mark.parametrize(
@@ -24,7 +17,6 @@ def processor_for_qwen2_vl():
 # yapf: enable
 @pytest.mark.parametrize("num_imgs", [1, 2])
 def test_processor_override(
-    processor_for_qwen2_vl,
     image_assets: _ImageAssets,
     model_id: str,
     mm_processor_kwargs: dict[str, object],
@@ -39,18 +31,20 @@ def test_processor_override(
         mm_processor_kwargs=None,
         limit_mm_per_prompt={"image": num_imgs},
     )
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-    ctx = InputProcessingContext(ctx.model_config, tokenizer)
+    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
+    processor = MULTIMODAL_REGISTRY.create_processor(
+        ctx.model_config,
+        tokenizer=tokenizer,
+    )
 
     # Build the image str / prompt based on the number of images we pass
     prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
     mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
 
-    processor = processor_for_qwen2_vl(ctx)
     processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
 
     # Ensure we have the right number of placeholders per num_crops size
-    hf_processor = processor._get_hf_processor(**mm_processor_kwargs)
+    hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
     image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
     img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
     pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape
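Note the access path in the last hunk: helpers that used to be private methods on the processor (for example processor._get_hf_processor) are now exposed on its info object. A small sketch of the accessors that appear in these diffs; the exact set of methods beyond those shown is not spelled out in this commit:

# Accessors visible in the diffs: the HF processor, tokenizer and config now
# live on the ProcessingInfo object attached to the processor.
hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
tokenizer = processor.info.get_tokenizer()
hf_config = processor.info.get_hf_config()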


@@ -10,12 +10,17 @@ from PIL import Image
 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
-                                        _PlaceholderInfo, find_mm_placeholders,
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache,
+                                        PromptReplacement,
+                                        find_mm_placeholders,
                                         find_text_matches, find_token_matches,
                                         iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
+# yapf: enable
+from vllm.multimodal.profiling import MultiModalProfiler
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import full_groupby
@@ -431,7 +436,7 @@ def test_find_replace_tokens(
         [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
         {
             "pattern_1": [
-                _PlaceholderInfo(
+                PlaceholderInfo(
                     modality="pattern_1",
                     item_idx=0,
                     start_idx=6,
@@ -445,13 +450,13 @@ def test_find_replace_tokens(
         [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
         {
             "pattern_1": [
-                _PlaceholderInfo(
+                PlaceholderInfo(
                     modality="pattern_1",
                     item_idx=0,
                     start_idx=1,
                     replacement=[32000, 32000],
                 ),
-                _PlaceholderInfo(
+                PlaceholderInfo(
                     modality="pattern_1",
                     item_idx=1,
                     start_idx=5,
@@ -459,7 +464,7 @@ def test_find_replace_tokens(
                 ),
             ],
             "pattern_3": [
-                _PlaceholderInfo(
+                PlaceholderInfo(
                     modality="pattern_3",
                     item_idx=0,
                     start_idx=7,
@@ -472,13 +477,13 @@ def test_find_replace_tokens(
         [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
         {
             "pattern_1": [
-                _PlaceholderInfo(
+                PlaceholderInfo(
                     modality="pattern_1",
                     item_idx=0,
                     start_idx=1,
                     replacement=[32000, 32000],
                 ),
-                _PlaceholderInfo(
+                PlaceholderInfo(
                     modality="pattern_1",
                     item_idx=1,
                     start_idx=3,
@@ -486,7 +491,7 @@ def test_find_replace_tokens(
                 ),
             ],
             "pattern_3": [
-                _PlaceholderInfo(
+                PlaceholderInfo(
                     modality="pattern_3",
                     item_idx=0,
                     start_idx=6,
@@ -577,19 +582,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         revision=None,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
-    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
-    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
-    ctx = InputProcessingContext(
+    processor = MULTIMODAL_REGISTRY.create_processor(
         model_config,
         tokenizer=cached_get_tokenizer(model_config.tokenizer),
     )
-
-    processor = processor_factory(ctx, cache=None)
-    profiler = processor.profiling_info
+    profiler = MultiModalProfiler(processor)
 
     mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
-    profiler.get_supported_mm_limits = mock_supported_mm_limits
+    processor.info.get_supported_mm_limits = mock_supported_mm_limits
 
     if is_valid:
         exc_ctx = nullcontext()
@@ -597,7 +598,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         exc_ctx = pytest.raises(ValueError, match="this model only supports")
 
     with exc_ctx:
-        profiler.get_mm_limits()
+        profiler.get_dummy_data(model_config.max_model_len)
 
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@@ -620,16 +621,12 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
         revision=None,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
-    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
-    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
-    ctx = InputProcessingContext(
+    processor = MULTIMODAL_REGISTRY.create_processor(
         model_config,
         tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )
-
-    processor = processor_factory(ctx, cache=None)
 
     rng = np.random.RandomState(0)
     image = _rand_img(rng, min_wh=128, max_wh=256)
     if num_images == 0:
@@ -681,9 +678,9 @@ def _test_processing_cache_correctness(
         hf_overrides=hf_overrides,
         limit_mm_per_prompt=limit_mm_per_prompt,
     )
 
-    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
-    processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
+    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
+
     ctx = InputProcessingContext(
         model_config,
         tokenizer=cached_get_tokenizer(model_config.tokenizer),
@@ -691,8 +688,9 @@ def _test_processing_cache_correctness(
     # Ensure that it can fit all of the data
     cache = ProcessingCache(capacity=1 << 30)
 
-    baseline_processor = processor_factory(ctx, cache=None)
-    cached_processor = processor_factory(ctx, cache=cache)
+    baseline_processor = factories.build_processor(ctx, cache=None)
+    cached_processor = factories.build_processor(ctx, cache=cache)
+    dummy_inputs = baseline_processor.dummy_inputs
 
     rng = np.random.RandomState(0)
@@ -724,7 +722,7 @@ def _test_processing_cache_correctness(
         }
         mm_counts = {k: len(vs) for k, vs in mm_data.items()}
 
-        prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
+        prompt = dummy_inputs.get_dummy_processor_inputs(
            model_config.max_model_len,
             mm_counts,
         ).prompt_text
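Profiling is the other half of the reorganization: the processor no longer owns a profiling_info object or a get_dummy_data method; a separate MultiModalProfiler wraps it instead, as the test changes above and the InputRegistry change below show. A minimal usage sketch, assuming model_config is an existing ModelConfig and the processor comes from the registry as in the tests:

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.multimodal.utils import cached_get_tokenizer

processor = MULTIMODAL_REGISTRY.create_processor(
    model_config,
    tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
profiler = MultiModalProfiler(processor)

# Dummy data for memory profiling; per the test above, this raises a
# ValueError ("this model only supports ...") when limit_mm_per_prompt
# exceeds what the model supports.
dummy_data = profiler.get_dummy_data(model_config.max_model_len)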


@@ -2,13 +2,17 @@ from typing import Optional
 
 import torch
 
-from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
-                                              LlavaMultiModalProcessor)
+from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder,
+                                              LlavaForConditionalGeneration,
+                                              LlavaMultiModalProcessor,
+                                              LlavaProcessingInfo)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 
-@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor,
+                                        info=LlavaProcessingInfo,
+                                        dummy_inputs=LlavaDummyInputsBuilder)
 class MyLlava(LlavaForConditionalGeneration):
 
     def compute_logits(


@@ -7,7 +7,7 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2
+from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputsV2
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.utils import print_info_once, print_warning_once


@@ -323,6 +323,7 @@ class InputRegistry:
         # Avoid circular import
         from vllm.model_executor.model_loader import get_model_architecture
         from vllm.multimodal import MultiModalKwargs
+        from vllm.multimodal.profiling import MultiModalProfiler
         from vllm.multimodal.utils import cached_get_tokenizer
 
         if mm_registry.has_processor(model_config):
@@ -331,7 +332,8 @@ class InputRegistry:
                 trust_remote_code=model_config.trust_remote_code,
             )
             processor = mm_registry.create_processor(model_config, tokenizer)
-            dummy_data = processor.get_dummy_data(seq_len)
+            profiler = MultiModalProfiler(processor)
+            dummy_data = profiler.get_dummy_data(seq_len)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:


@@ -23,10 +23,10 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors)
+from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
-                                        PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
                                                   AriaVisionConfig)
@@ -445,33 +445,33 @@ def build_mm_projector(config: PretrainedConfig):
     )
 
 
-class AriaProcessingMixin(ProcessingMixin):
+class AriaProcessingInfo(BaseProcessingInfo):
 
-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config()
 
-    def _get_vision_config(self) -> AriaVisionConfig:
-        return self._get_hf_config().vision_config
-
-    def _get_num_image_tokens(self) -> int:
-        hf_config = self._get_hf_config()
-        return max(hf_config.projector_patch_to_query_dict.values())
-
-
-class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
+    def get_vision_config(self) -> AriaVisionConfig:
+        return self.get_hf_config().vision_config
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {"image": self._get_num_image_tokens()}
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        hf_config = self.get_hf_config()
+        return max(hf_config.projector_patch_to_query_dict.values())
+
+
+class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
 
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        vision_config = self._get_vision_config()
+        vision_config = self.info.get_vision_config()
         max_image_size = vision_config.image_size
         num_images = mm_counts.get("image", 0)
@@ -483,7 +483,7 @@ class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
                                    num_images=num_images)
         }
 
-        hf_processor = self._get_hf_processor()
+        hf_processor = self.info.get_hf_processor()
         image_token: str = hf_processor.image_token  # type: ignore
 
         return ProcessorInputs(
@@ -492,10 +492,7 @@ class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
     )
 
 
-class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return AriaProfilingInfo(self.ctx)
+class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
 
     def _get_mm_fields_config(
         self,
@@ -513,10 +510,10 @@ class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index
 
-        num_image_tokens = self._get_num_image_tokens()
+        num_image_tokens = self.info.get_num_image_tokens()
 
         return [
             PromptReplacement(
@@ -527,7 +524,9 @@ class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
     ]
 
 
-@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor,
+                                        info=AriaProcessingInfo,
+                                        dummy_inputs=AriaDummyInputsBuilder)
 class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """
     Aria model for conditional generation tasks.
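As a concrete illustration of AriaProcessingInfo.get_num_image_tokens() above: the per-image token budget is simply the largest query count in the projector's patch-to-query mapping. The mapping values below are made up for the example; the real ones come from the model's HF config:

import math  # not needed here, kept minimal on purpose

# Hypothetical projector_patch_to_query_dict for illustration only.
projector_patch_to_query_dict = {1225: 128, 4900: 256}

num_image_tokens = max(projector_patch_to_query_dict.values())
assert num_image_tokens == 256  # what get_mm_max_tokens_per_item would report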


@@ -17,10 +17,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessingMixin,
-                                        PromptReplacement)
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+                                        BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .blip import BlipVisionModel
@@ -397,30 +397,30 @@ class Blip2QFormerModel(nn.Module):
         return sequence_output
 
 
-class Blip2ProcessingMixin(ProcessingMixin):
+class Blip2ProcessingInfo(BaseProcessingInfo):
 
-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(Blip2Config)
 
-    def _get_num_image_tokens(self) -> int:
-        hf_config = self._get_hf_config()
-        return hf_config.num_query_tokens
-
-
-class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": 1}
 
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {"image": self._get_num_image_tokens()}
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        hf_config = self.get_hf_config()
+        return hf_config.num_query_tokens
+
+
+class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
 
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         vision_config = hf_config.vision_config
         max_image_size = vision_config.image_size
@@ -439,10 +439,7 @@ class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
     )
 
 
-class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return Blip2ProfilingInfo(self.ctx)
+class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
 
     def _get_mm_fields_config(
         self,
@@ -460,7 +457,7 @@ class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        num_image_tokens = self._get_num_image_tokens()
+        num_image_tokens = self.info.get_num_image_tokens()
 
         return [
             PromptReplacement(
@@ -491,7 +488,9 @@ class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
         return result
 
 
-@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
+                                        info=Blip2ProcessingInfo,
+                                        dummy_inputs=Blip2DummyInputsBuilder)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):


@ -30,10 +30,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange) NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin, BaseProcessingInfo, PromptReplacement)
PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
@ -49,33 +49,34 @@ class ChameleonImagePixelInputs(TypedDict):
"""Shape: `(batch_size * num_images, num_channels, height, width)`""" """Shape: `(batch_size * num_images, num_channels, height, width)`"""
class ChameleonProcessingMixin(ProcessingMixin): class ChameleonProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(ChameleonConfig) return self.ctx.get_hf_config(ChameleonConfig)
def _get_hf_processor(self): def get_hf_processor(self):
return self.ctx.get_hf_processor(ChameleonProcessor) return self.ctx.get_hf_processor(ChameleonProcessor)
def _get_num_image_tokens(self) -> int:
processor = self._get_hf_processor()
return processor.image_seq_length
class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1} return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()} return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
processor = self.get_hf_processor()
return processor.image_seq_length
class ChameleonDummyInputsBuilder(
BaseDummyInputsBuilder[ChameleonProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
config = self._get_hf_config() config = self.info.get_hf_config()
width = height = config.vq_config.resolution width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
@ -93,11 +94,8 @@ class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
) )
class ChameleonMultiModalProcessor(ChameleonProcessingMixin, class ChameleonMultiModalProcessor(
BaseMultiModalProcessor): BaseMultiModalProcessor[ChameleonProcessingInfo]):
def _get_profiling_info(self) -> BaseProfilingInfo:
return ChameleonProfilingInfo(self.ctx)
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
@ -112,7 +110,7 @@ class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
processor = self._get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
return [ return [
PromptReplacement( PromptReplacement(
@ -120,7 +118,7 @@ class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
target="<image>", target="<image>",
replacement="".join([ replacement="".join([
processor.image_start_token, processor.image_start_token,
processor.image_token * self._get_num_image_tokens(), processor.image_token * self.info.get_num_image_tokens(),
processor.image_end_token, processor.image_end_token,
]), ]),
) )
@ -916,7 +914,10 @@ class ChameleonModel(nn.Module):
return hidden_states return hidden_states
@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor) @MULTIMODAL_REGISTRY.register_processor(
ChameleonMultiModalProcessor,
info=ChameleonProcessingInfo,
dummy_inputs=ChameleonDummyInputsBuilder)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):


@ -33,11 +33,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange) NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin, BaseProcessingInfo, PromptReplacement)
PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
@ -64,24 +64,38 @@ class FuyuImagePatchInputs(TypedDict):
""" """
class FuyuProcessingMixin(ProcessingMixin): class FuyuProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(FuyuConfig) return self.ctx.get_hf_config(FuyuConfig)
def _get_hf_processor(self): def get_hf_processor(self):
return self.ctx.get_hf_processor(FuyuProcessor) return self.ctx.get_hf_processor(FuyuProcessor)
def _get_image_processor(self) -> FuyuImageProcessor: def get_image_processor(self) -> FuyuImageProcessor:
return self._get_hf_processor().image_processor return self.get_hf_processor().image_processor
def _get_image_feature_grid_size( def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_ncols, max_nrows = self.get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
max_image_tokens = (max_ncols + 1) * max_nrows
return {"image": max_image_tokens}
def get_image_feature_grid_size(
self, self,
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
) -> tuple[int, int]: ) -> tuple[int, int]:
image_processor = self._get_image_processor() image_processor = self.get_image_processor()
target_width = image_processor.size["width"] target_width = image_processor.size["width"]
target_height = image_processor.size["height"] target_height = image_processor.size["height"]
@ -97,34 +111,21 @@ class FuyuProcessingMixin(ProcessingMixin):
nrows = math.ceil(image_height / 30) nrows = math.ceil(image_height / 30)
return ncols, nrows return ncols, nrows
def get_image_size_with_most_features(self) -> ImageSize:
class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo): image_processor = self.get_image_processor()
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_ncols, max_nrows = self._get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
max_image_tokens = (max_ncols + 1) * max_nrows
return {"image": max_image_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
image_processor = self._get_image_processor()
return ImageSize(width=image_processor.size["width"], return ImageSize(width=image_processor.size["width"],
height=image_processor.size["height"]) height=image_processor.size["height"])
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
mm_data = { mm_data = {
@ -140,10 +141,7 @@ class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):
) )
class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor): class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
def _get_profiling_info(self) -> BaseProfilingInfo:
return FuyuProfilingInfo(self.ctx)
def _call_hf_processor( def _call_hf_processor(
self, self,
@ -156,7 +154,7 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
# Avoid warning from HF logger for text-only input # Avoid warning from HF logger for text-only input
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id # Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
# Tokenizer won't add boa_token_id by default, we add it manually. # Tokenizer won't add boa_token_id by default, we add it manually.
tokenizer = self._get_tokenizer() tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
prompt_ids = tokenizer.encode(prompt) + [boa_token_id] prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
@ -196,10 +194,10 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_config = self._get_hf_config() hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id bos_token_id = hf_config.bos_token_id
tokenizer = self._get_tokenizer() tokenizer = self.info.get_tokenizer()
eot_token_id = tokenizer.bos_token_id eot_token_id = tokenizer.bos_token_id
assert isinstance(eot_token_id, int) assert isinstance(eot_token_id, int)
@ -207,7 +205,7 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
ncols, nrows = self._get_image_feature_grid_size( ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
) )
@ -244,7 +242,9 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
return result return result
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor) @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
info=FuyuProcessingInfo,
dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
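For Fuyu, the per-image budget that get_mm_max_tokens_per_item reports is (max_ncols + 1) * max_nrows, where the grid comes from the image processor's target size divided into 30x30 patches and the extra column corresponds to the row-separator token Fuyu emits at the end of each patch row. A worked example, assuming the processor's default target size of 1920x1080 (the size itself is not shown in this diff):

import math

# Assumed FuyuImageProcessor.size for this example (not shown in the diff).
target_width, target_height = 1920, 1080

max_ncols = math.ceil(target_width / 30)   # 64 patch columns
max_nrows = math.ceil(target_height / 30)  # 36 patch rows

# Mirrors (max_ncols + 1) * max_nrows from the diff above.
max_image_tokens = (max_ncols + 1) * max_nrows
assert max_image_tokens == 2340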


@ -1,7 +1,7 @@
from abc import ABC, abstractmethod from abc import abstractmethod
from functools import cached_property from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional, from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union) Protocol, Set, Tuple, TypedDict, TypeVar, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -25,11 +25,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingCache, BaseProcessingInfo, ProcessingCache,
ProcessingMixin, PromptReplacement) PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
@ -105,34 +105,23 @@ class LlavaLikeProcessor(Protocol):
image_token: Final[str] image_token: Final[str]
class BaseLlavaProcessingMixin(ProcessingMixin, ABC): class BaseLlavaProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self) -> LlavaLikeConfig: def get_hf_config(self) -> LlavaLikeConfig:
return self.ctx.get_hf_config(LlavaConfig) return self.ctx.get_hf_config(LlavaConfig)
def _get_vision_encoder_info(self): def get_vision_encoder_info(self):
return get_vision_encoder_info(self._get_hf_config()) return get_vision_encoder_info(self.get_hf_config())
@abstractmethod @abstractmethod
def _get_hf_processor(self) -> LlavaLikeProcessor: def get_hf_processor(self) -> LlavaLikeProcessor:
raise NotImplementedError raise NotImplementedError
def _get_num_image_tokens( def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
self, return {"image": None}
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
vision_encoder_info = self._get_vision_encoder_info()
return self._apply_feature_select_strategy( def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config.vision_feature_select_strategy, return {"image": self.get_max_image_tokens()}
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def _apply_feature_select_strategy( def _apply_feature_select_strategy(
self, self,
@ -147,28 +136,42 @@ class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
msg = f"Unexpected feature select strategy: {strategy!r}" msg = f"Unexpected feature select strategy: {strategy!r}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo): return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_image_size_with_most_features(self) -> ImageSize:
return {"image": None} vision_encoder_info = self.get_vision_encoder_info()
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size() width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height) return ImageSize(width=width, height=height)
def _get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_image_tokens( return self.get_num_image_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
) )
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
@ -176,9 +179,10 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
) -> ProcessorInputs: ) -> ProcessorInputs:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
processor = self._get_hf_processor() processor = self.info.get_hf_processor()
image_token = processor.image_token image_token = processor.image_token
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = { mm_data = {
"image": "image":
@ -193,23 +197,13 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
) )
class LlavaProcessingMixin(BaseLlavaProcessingMixin): class LlavaProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_processor(self): def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaProcessor) return self.ctx.get_hf_processor(LlavaProcessor)
class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo): class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
pass
class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
BaseMultiModalProcessor):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_profiling_info(self) -> BaseProfilingInfo:
raise NotImplementedError
# Copied from BaseMultiModalProcessor # Copied from BaseMultiModalProcessor
@abstractmethod @abstractmethod
@ -226,7 +220,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_config = self._get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
def get_replacement(item_idx: int): def get_replacement(item_idx: int):
@ -237,7 +231,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
num_image_tokens = images.get_feature_size(item_idx) num_image_tokens = images.get_feature_size(item_idx)
else: else:
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
num_image_tokens = self._get_num_image_tokens( num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
) )
@ -253,10 +247,8 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
] ]
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor): class LlavaMultiModalProcessor(
BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaProfilingInfo(self.ctx)
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
@ -269,21 +261,14 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
) )
class PixtralHFProcessingMixin(BaseLlavaProcessingMixin): class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_processor(self): def get_hf_processor(self):
return self.ctx.get_hf_processor(PixtralProcessor) return self.ctx.get_hf_processor(PixtralProcessor)
class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo): class PixtralHFMultiModalProcessor(
pass BaseMultiModalProcessor[PixtralHFProcessingInfo]):
class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return PixtralHFProfilingInfo(self.ctx)
def _call_hf_processor( def _call_hf_processor(
self, self,
@ -328,10 +313,10 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_config = self._get_hf_config() hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index

-        processor = self._get_hf_processor()
+        processor = self.info.get_hf_processor()
         image_token = processor.image_token
         image_break_token = processor.image_break_token
         image_end_token = processor.image_end_token

@@ -363,26 +348,40 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
         ]

+def _build_llava_or_pixtral_hf_info(
+    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
+    hf_config = ctx.get_hf_config(LlavaConfig)
+
+    if isinstance(hf_config.vision_config, PixtralVisionConfig):
+        return PixtralHFProcessingInfo(ctx)
+
+    return LlavaProcessingInfo(ctx)
+
+
 def _build_llava_or_pixtral_hf_processor(
-    ctx: InputProcessingContext,
+    info: _I,
+    dummy_inputs: BaseDummyInputsBuilder[_I],
     *,
     cache: Optional[ProcessingCache] = None,
     enable_sanity_checks: bool = True,
 ) -> BaseMultiModalProcessor:
-    hf_config = ctx.get_hf_config(LlavaConfig)
-    if isinstance(hf_config.vision_config, PixtralVisionConfig):
+    if isinstance(info, PixtralHFProcessingInfo):
         return PixtralHFMultiModalProcessor(
-            ctx,
+            info,
+            dummy_inputs,  # type: ignore
             cache=cache,
             enable_sanity_checks=enable_sanity_checks,
         )

-    return LlavaMultiModalProcessor(
-        ctx,
-        cache=cache,
-        enable_sanity_checks=enable_sanity_checks,
-    )
+    if isinstance(info, LlavaProcessingInfo):
+        return LlavaMultiModalProcessor(
+            info,
+            dummy_inputs,  # type: ignore
+            cache=cache,
+            enable_sanity_checks=enable_sanity_checks,
+        )
+
+    raise NotImplementedError(type(info))

 def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:

@@ -460,7 +459,9 @@ def init_vision_tower_for_llava(
     raise NotImplementedError(msg)

-@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
+@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
+                                        info=_build_llava_or_pixtral_hf_info,
+                                        dummy_inputs=LlavaDummyInputsBuilder)
 class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {

@@ -727,11 +728,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index

         # Assume that it doesn't depend on the image size
-        num_image_tokens = self._get_num_image_tokens(
+        num_image_tokens = self.info.get_num_image_tokens(
             image_width=-1,
             image_height=-1,
         )

@@ -796,6 +797,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 # To use this model, please use
 # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
-@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
+                                        info=LlavaProcessingInfo,
+                                        dummy_inputs=LlavaDummyInputsBuilder)
 class MantisForConditionalGeneration(LlavaForConditionalGeneration):
     pass
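The hunks above replace the old single-argument processor factory with a three-part registration: a multi-modal processor, a ProcessingInfo class, and a DummyInputsBuilder class. The sketch below shows the shape of that registration for a hypothetical model; every "My*" name is invented for illustration, and only the vLLM base classes and the register_processor() keywords visible in this diff are assumed.

# NOTE: illustrative sketch only, not part of this commit. All "My*" names
# are hypothetical; abstract processor hooks are omitted for brevity.
from collections.abc import Mapping
from typing import Optional

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs


class MyProcessingInfo(BaseProcessingInfo):
    """Static, model-level knowledge: configs, limits, worst-case token counts."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}  # no hard limit on images per prompt

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": 576}  # placeholder worst-case budget for illustration


class MyDummyInputsBuilder(BaseDummyInputsBuilder[MyProcessingInfo]):
    """Builds profiling inputs, reading static facts through ``self.info``."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        raise NotImplementedError  # would return a dummy prompt plus media


class MyMultiModalProcessor(BaseMultiModalProcessor[MyProcessingInfo]):
    """Request-time processing: HF processor calls, prompt replacements."""


@MULTIMODAL_REGISTRY.register_processor(MyMultiModalProcessor,
                                        info=MyProcessingInfo,
                                        dummy_inputs=MyDummyInputsBuilder)
class MyForConditionalGeneration:  # in vLLM this would be an nn.Module
    ...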
@@ -1,6 +1,7 @@
+from abc import abstractmethod
 from functools import cached_property
 from typing import (Final, Iterable, List, Literal, Mapping, Optional,
-                    Protocol, Set, Tuple, TypedDict, Union)
+                    Protocol, Set, Tuple, TypedDict, TypeVar, Union)

 import torch
 import torch.nn as nn
@@ -16,13 +17,12 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
 from vllm.multimodal.parse import ImageSize
-from vllm.multimodal.profiling import BaseProfilingInfo
 from vllm.sequence import IntermediateTensors

 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin,
-                    BaseLlavaProfilingInfo, LlavaLikeConfig,
+from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
+                    LlavaDummyInputsBuilder, LlavaLikeConfig,
                     LlavaMultiModalProjector, init_vision_tower_for_llava)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
@@ -65,23 +65,23 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
     image_grid_pinpoints: Final[list[list[int]]]

-class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
+class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):

-    def _get_hf_config(self) -> LlavaNextLikeConfig:
+    def get_hf_config(self) -> LlavaNextLikeConfig:
         return self.ctx.get_hf_config(LlavaNextConfig)

-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaNextProcessor)

     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
-    def _get_num_image_tokens(
+    def get_num_image_tokens(
         self,
         *,
         image_width: int,
         image_height: int,
     ) -> int:
-        hf_config = self._get_hf_config()
-        vision_encoder_info = self._get_vision_encoder_info()
+        hf_config = self.get_hf_config()
+        vision_encoder_info = self.get_vision_encoder_info()

         base_feature_size = self._apply_feature_select_strategy(
             hf_config.vision_feature_select_strategy,
@@ -140,16 +140,13 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
         return (unpadded_features, newline_features)

-
-class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
-
-    def _get_image_size_with_most_features(self) -> ImageSize:
-        hf_config = self._get_hf_config()
+    def get_image_size_with_most_features(self) -> ImageSize:
+        hf_config = self.get_hf_config()

         largest_feature_size, largest_feature_pinpoint = 0, None
         for (height, width) in hf_config.image_grid_pinpoints:
-            feat_size = self._get_num_image_tokens(image_width=width,
-                                                   image_height=height)
+            feat_size = self.get_num_image_tokens(image_width=width,
+                                                  image_height=height)
             if feat_size > largest_feature_size:
                 largest_feature_size = feat_size
                 largest_feature_pinpoint = ImageSize(width=width,
@@ -161,11 +158,23 @@ class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
         return largest_feature_pinpoint

-class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
-                                   BaseLlavaMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return LlavaNextProfilingInfo(self.ctx)
+_I = TypeVar("_I", bound=LlavaNextProcessingInfo)
+
+
+class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
+
+    # Copied from BaseMultiModalProcessor
+    @abstractmethod
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        raise NotImplementedError
+
+
+class LlavaNextMultiModalProcessor(
+        BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):

     def _get_mm_fields_config(
         self,
@@ -179,7 +188,9 @@ class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
     )

-@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
+                                        info=LlavaNextProcessingInfo,
+                                        dummy_inputs=LlavaDummyInputsBuilder)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
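The `_I = TypeVar("_I", bound=LlavaNextProcessingInfo)` plus `BaseLlavaNextMultiModalProcessor[_I]` change above keeps the shared processor logic in one generic base class while each subclass pins down exactly which info type `self.info` is. A standalone sketch of the same idiom follows; all class names and the toy token formula are invented, not the vLLM classes.

# NOTE: standalone sketch of the Generic-over-info idiom; invented names only.
from typing import Generic, TypeVar


class BaseInfo:

    def get_hf_config(self) -> dict:
        return {"image_grid_pinpoints": [[336, 672], [672, 336]]}


class NextInfo(BaseInfo):

    def get_num_image_tokens(self, *, width: int, height: int) -> int:
        return (width // 14) * (height // 14)  # toy formula


_I = TypeVar("_I", bound=NextInfo)


class BaseNextProcessor(Generic[_I]):
    """Shared logic typed against the bound, so `self.info` stays precise."""

    def __init__(self, info: _I) -> None:
        self.info = info

    def max_tokens(self) -> int:
        cfg = self.info.get_hf_config()
        return max(
            self.info.get_num_image_tokens(width=w, height=h)
            for h, w in cfg["image_grid_pinpoints"])


class OnevisionInfo(NextInfo):
    pass


class OnevisionProcessor(BaseNextProcessor[OnevisionInfo]):
    pass


print(OnevisionProcessor(OnevisionInfo()).max_tokens())  # 1152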
@ -17,12 +17,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import (ImageSize, VideoEmbeddingItems, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoProcessorItems) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin, BaseProcessingInfo, PromptReplacement)
PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
@ -47,33 +46,52 @@ class LlavaNextVideoPixelInputs(TypedDict):
""" """
-class LlavaNextVideoProcessingMixin(ProcessingMixin):
+class LlavaNextVideoProcessingInfo(BaseProcessingInfo):

-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(LlavaNextVideoConfig)

-    def _get_vision_encoder_info(self):
-        return get_vision_encoder_info(self._get_hf_config())
+    def get_vision_encoder_info(self):
+        return get_vision_encoder_info(self.get_hf_config())

-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaNextVideoProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len),
)
return {"video": max_video_tokens}
def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_num_frame_tokens( def _get_num_frame_tokens(
self, self,
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
) -> int: ) -> int:
hf_config = self._get_hf_config() hf_config = self.get_hf_config()
spatial_pool_stride = hf_config.spatial_pool_stride spatial_pool_stride = hf_config.spatial_pool_stride
vision_encoder_info = self._get_vision_encoder_info() vision_encoder_info = self.get_vision_encoder_info()
patch_grid_length = vision_encoder_info.get_patch_grid_length() patch_grid_length = vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
return pooled_grid_length * pooled_grid_length return pooled_grid_length * pooled_grid_length
def _get_num_video_tokens( def get_num_video_tokens(
self, self,
*, *,
image_width: int, image_width: int,
@ -87,37 +105,14 @@ class LlavaNextVideoProcessingMixin(ProcessingMixin):
return num_frame_tokens * num_frames return num_frame_tokens * num_frames
class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_video_tokens = self._get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
)
return {"video": max_video_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_max_video_frames(self, max_tokens: int) -> int: def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0 num_frames = 0
while True: while True:
next_num_frames = num_frames + 1 next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens( next_max_tokens = self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=next_num_frames, num_frames=next_num_frames,
@ -130,7 +125,7 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
return num_frames return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int: def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config() mm_config = self.ctx.get_mm_config()
max_videos = mm_config.limit_per_prompt.get("video", 1) max_videos = mm_config.limit_per_prompt.get("video", 1)
@ -138,6 +133,10 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
return max(max_total_frames // max(max_videos, 1), 1) return max(max_total_frames // max(max_videos, 1), 1)
class LlavaNextVideoDummyInputsBuilder(
BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
@ -145,16 +144,20 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
) -> ProcessorInputs: ) -> ProcessorInputs:
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
processor = self._get_hf_processor() processor = self.info.get_hf_processor()
video_token = processor.video_token video_token = processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
mm_data = { mm_data = {
"video": "video":
self._get_dummy_videos( self._get_dummy_videos(
width=target_width, width=target_width,
height=target_height, height=target_height,
num_frames=self._get_dummy_num_frames(seq_len), num_frames=target_num_frames,
num_videos=num_videos, num_videos=num_videos,
) )
} }
@ -165,11 +168,8 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
) )
-class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
-                                        BaseMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return LlavaNextVideoProfilingInfo(self.ctx)
+class LlavaNextVideoMultiModalProcessor(
+        BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
@ -184,7 +184,7 @@ class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_config = self._get_hf_config() hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index video_token_id = hf_config.video_token_index
def get_replacement(item_idx: int): def get_replacement(item_idx: int):
@ -195,7 +195,7 @@ class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
num_video_tokens = videos.get_feature_size(item_idx) num_video_tokens = videos.get_feature_size(item_idx)
else: else:
image_size = videos.get_frame_size(item_idx) image_size = videos.get_frame_size(item_idx)
num_video_tokens = self._get_num_video_tokens( num_video_tokens = self.info.get_num_video_tokens(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
num_frames=videos.get_num_frames(item_idx), num_frames=videos.get_num_frames(item_idx),
@ -269,7 +269,11 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states return hidden_states
-@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(
+    LlavaNextVideoMultiModalProcessor,
+    info=LlavaNextVideoProcessingInfo,
+    dummy_inputs=LlavaNextVideoDummyInputsBuilder,
+)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
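Above, the old ProfilingInfo class is split in two: sizing knowledge moves into the ProcessingInfo, and the dummy-inputs builder reads it back through `self.info` when constructing profiling media. A standalone sketch of that split, with invented class names and toy sizes:

# NOTE: standalone sketch of the info / dummy-builder split; invented names
# and shapes, not the vLLM classes.
import numpy as np


class VideoInfo:
    """Static sizing knowledge (what used to live in *ProfilingInfo)."""

    def get_image_size_with_most_features(self) -> tuple[int, int]:
        return 336, 336  # width, height of the worst-case frame

    def get_num_frames_with_most_features(self, seq_len: int) -> int:
        tokens_per_frame = 144  # toy value
        return max(seq_len // tokens_per_frame, 1)


class VideoDummyInputsBuilder:
    """Builds dummy media for memory profiling, reading sizes from `info`."""

    def __init__(self, info: VideoInfo) -> None:
        self.info = info

    def build(self, seq_len: int, num_videos: int) -> list[np.ndarray]:
        width, height = self.info.get_image_size_with_most_features()
        num_frames = self.info.get_num_frames_with_most_features(seq_len)
        # One all-zero RGB clip per requested video.
        return [
            np.zeros((num_frames, height, width, 3), dtype=np.uint8)
            for _ in range(num_videos)
        ]


clips = VideoDummyInputsBuilder(VideoInfo()).build(seq_len=4096, num_videos=1)
print(clips[0].shape)  # (28, 336, 336, 3)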
@ -17,19 +17,20 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, NestedTensors)
VideoEmbeddingItems, VideoProcessorItems) from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement VideoProcessorItems)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs from vllm.multimodal.processing import PromptReplacement
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor, from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
LlavaNextProcessingMixin) LlavaNextProcessingInfo)
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
@ -89,14 +90,23 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
video_token_index: Final[int] video_token_index: Final[int]
-class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
+class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):

-    def _get_hf_config(self) -> LlavaOnevisionLikeConfig:
+    def get_hf_config(self) -> LlavaOnevisionLikeConfig:
         return self.ctx.get_hf_config(LlavaOnevisionConfig)

-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaOnevisionProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
}
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86 # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# with additional logic afterwards taken from LlavaOnevisionProcessor # with additional logic afterwards taken from LlavaOnevisionProcessor
def _get_num_unpadded_features( def _get_num_unpadded_features(
@ -141,16 +151,16 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
image_width: int, image_width: int,
image_height: int, image_height: int,
) -> int: ) -> int:
hf_config = self._get_hf_config() hf_config = self.get_hf_config()
spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2) spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)
vision_encoder_info = self._get_vision_encoder_info() vision_encoder_info = self.get_vision_encoder_info()
patch_grid_length = vision_encoder_info.get_patch_grid_length() patch_grid_length = vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
return pooled_grid_length * pooled_grid_length return pooled_grid_length * pooled_grid_length
def _get_num_video_tokens( def get_num_video_tokens(
self, self,
*, *,
image_width: int, image_width: int,
@ -164,43 +174,14 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
return num_frame_tokens * num_frames + 1 # Newline token return num_frame_tokens * num_frames + 1 # Newline token
class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
BaseLlavaProfilingInfo):
def _get_image_size_with_most_features(self) -> ImageSize:
hf_config = self._get_hf_config()
largest_feature_size, largest_feature_pinpoint = 0, None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = self._get_num_image_tokens(image_width=width,
image_height=height)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self._get_max_image_tokens(),
"video": self._get_max_video_tokens(seq_len),
}
def _get_max_video_frames(self, max_tokens: int) -> int: def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0 num_frames = 0
while True: while True:
next_num_frames = num_frames + 1 next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens( next_max_tokens = self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=next_num_frames, num_frames=next_num_frames,
@ -213,12 +194,12 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
return num_frames return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int: def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config() mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1) max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1) max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len - max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens) max_image_tokens)
max_frames_per_video = min(max_total_frames // max(max_videos, 1), max_frames_per_video = min(max_total_frames // max(max_videos, 1),
@ -226,15 +207,19 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
return max(max_frames_per_video, 1) return max(max_frames_per_video, 1)
def _get_max_video_tokens(self, seq_len: int) -> int: def get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_video_tokens( return self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len), num_frames=self.get_num_frames_with_most_features(seq_len),
) )
class LlavaOnevisionDummyInputsBuilder(
LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
@ -243,10 +228,14 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
processor = self._get_hf_processor() processor = self.info.get_hf_processor()
image_token = processor.image_token image_token = processor.image_token
video_token = processor.video_token video_token = processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
mm_data = { mm_data = {
"image": "image":
@ -257,7 +246,7 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
self._get_dummy_videos( self._get_dummy_videos(
width=target_width, width=target_width,
height=target_height, height=target_height,
num_frames=self._get_dummy_num_frames(seq_len), num_frames=target_num_frames,
num_videos=num_videos, num_videos=num_videos,
) )
} }
@ -268,11 +257,8 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
) )
-class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
-                                        LlavaNextMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return LlavaOnevisionProfilingInfo(self.ctx)
+class LlavaOnevisionMultiModalProcessor(
+        BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
@ -303,7 +289,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
processor = self._get_hf_processor() processor = self.info.get_hf_processor()
video_token = processor.video_token video_token = processor.video_token
# LLaVA-OneVision processor doesn't support multiple videos # LLaVA-OneVision processor doesn't support multiple videos
@ -345,7 +331,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
out_mm_kwargs=out_mm_kwargs, out_mm_kwargs=out_mm_kwargs,
) )
hf_config = self._get_hf_config() hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index video_token_id = hf_config.video_token_index
def get_video_replacement(item_idx: int): def get_video_replacement(item_idx: int):
@ -356,7 +342,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
num_video_tokens = videos.get_feature_size(item_idx) num_video_tokens = videos.get_feature_size(item_idx)
else: else:
image_size = videos.get_frame_size(item_idx) image_size = videos.get_frame_size(item_idx)
num_video_tokens = self._get_num_video_tokens( num_video_tokens = self.info.get_num_video_tokens(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
num_frames=videos.get_num_frames(item_idx), num_frames=videos.get_num_frames(item_idx),
@ -393,7 +379,10 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
return hidden_states return hidden_states
-@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(
+    LlavaOnevisionMultiModalProcessor,
+    info=LlavaOnevisionProcessingInfo,
+    dummy_inputs=LlavaOnevisionDummyInputsBuilder)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
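The renamed `get_num_frames_with_most_features` above budgets tokens in two steps: reserve the worst-case image tokens, then spend the remaining sequence length on video frames split across the allowed number of videos. A standalone sketch of that arithmetic; the helper is simplified and the concrete numbers are invented:

# NOTE: standalone sketch of the frame-budget arithmetic; toy numbers only.
def frames_with_most_features(seq_len: int,
                              max_images: int,
                              max_videos: int,
                              tokens_per_image: int,
                              tokens_per_frame: int) -> int:
    # Reserve the worst-case image budget first, spend the rest on video.
    image_budget = tokens_per_image * max_images
    video_budget = max(seq_len - image_budget, 0)
    total_frames = video_budget // tokens_per_frame
    # Split the remaining frames across the allowed number of videos.
    return max(total_frames // max(max_videos, 1), 1)


print(frames_with_most_features(seq_len=8192, max_images=1, max_videos=1,
                                tokens_per_image=729, tokens_per_frame=196))
# -> 38 frames fit after reserving one worst-case image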
@ -34,13 +34,12 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange) NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin, BaseProcessingInfo,
PromptReplacement, BoundPromptReplacement,
_BoundPromptReplacement, PlaceholderInfo, PromptReplacement)
_PlaceholderInfo) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
@ -302,9 +301,9 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline return image_features_hd_newline
-class Phi3VProcessingMixin(ProcessingMixin):
+class Phi3VProcessingInfo(BaseProcessingInfo):

-    def _get_hf_processor(
+    def get_hf_processor(
         self,
         *,
         num_crops: Optional[int] = None,
@@ -314,39 +313,42 @@ class Phi3VProcessingMixin(ProcessingMixin):
         return self.ctx.get_hf_processor()
def _get_num_image_tokens( def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
return {"image": max_image_tokens}
def get_num_image_tokens(
self, self,
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Optional[ProcessorMixin],
) -> int: ) -> int:
processor = self._get_hf_processor() if processor is None:
processor = self.get_hf_processor()
return processor.calc_num_image_tokens_from_image_size( # type: ignore return processor.calc_num_image_tokens_from_image_size( # type: ignore
width=image_width, width=image_width,
height=image_height, height=image_height,
) )
def get_image_size_with_most_features(self) -> ImageSize:
class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_image_tokens = self._get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
return {"image": max_image_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
# Result in the max possible feature size (h:w = 16:1) # Result in the max possible feature size (h:w = 16:1)
return ImageSize(height=8000, width=50) return ImageSize(height=8000, width=50)
class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
@ -354,7 +356,8 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
) -> ProcessorInputs: ) -> ProcessorInputs:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = { mm_data = {
"image": "image":
@ -363,7 +366,7 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
num_images=num_images) num_images=num_images)
} }
hf_processor = self._get_hf_processor() hf_processor = self.info.get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return ProcessorInputs( return ProcessorInputs(
@ -372,10 +375,7 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
) )
-class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return Phi3VProfilingInfo(self.ctx)
+class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
def _call_hf_processor( def _call_hf_processor(
self, self,
@ -416,10 +416,10 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore image_tokens: list[str] = hf_processor.img_tokens # type: ignore
tokenizer = self._get_tokenizer() tokenizer = self.info.get_tokenizer()
bos_token_id = tokenizer.bos_token_id bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int) assert isinstance(bos_token_id, int)
@ -431,9 +431,10 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
num_image_tokens = images.get_feature_size(item_idx) num_image_tokens = images.get_feature_size(item_idx)
else: else:
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
num_image_tokens = self._get_num_image_tokens( num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
processor=hf_processor,
) )
return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id] return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
@ -451,9 +452,9 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
def _apply_prompt_replacements( def _apply_prompt_replacements(
self, self,
token_ids: list[int], token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements( token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids, token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls, mm_prompt_repls=mm_prompt_repls,
@ -466,7 +467,7 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
token_ids = [token_ids[0], *token_ids[2:]] token_ids = [token_ids[0], *token_ids[2:]]
placeholders = { placeholders = {
modality: [ modality: [
_PlaceholderInfo( PlaceholderInfo(
modality=p.modality, modality=p.modality,
item_idx=p.item_idx, item_idx=p.item_idx,
start_idx=p.start_idx - 1, start_idx=p.start_idx - 1,
@ -499,7 +500,9 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
return result return result
-@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
+                                        info=Phi3VProcessingInfo,
+                                        dummy_inputs=Phi3VDummyInputsBuilder)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
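The Phi-3V hunks above add an explicit `processor` parameter to `get_num_image_tokens`: callers that already hold an HF processor pass it in, and only callers that do not (such as the profiling path) let the method build a fresh one. A standalone sketch of that "pass it in, or fall back" pattern, with an invented stand-in for the expensive processor:

# NOTE: standalone sketch; ExpensiveProcessor and the token formula are
# invented stand-ins, not the real HF processor.
from typing import Optional


class ExpensiveProcessor:

    def calc_num_image_tokens_from_image_size(self, *, width: int,
                                              height: int) -> int:
        return (width * height) // 1000  # toy formula


class Info:

    def get_hf_processor(self) -> ExpensiveProcessor:
        return ExpensiveProcessor()  # imagine this is costly to construct

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[ExpensiveProcessor],
    ) -> int:
        if processor is None:
            processor = self.get_hf_processor()
        return processor.calc_num_image_tokens_from_image_size(
            width=image_width, height=image_height)


info = Info()
shared = info.get_hf_processor()
# Reuse one processor across many items instead of rebuilding it per call.
print([
    info.get_num_image_tokens(image_width=w, image_height=h, processor=shared)
    for w, h in [(640, 480), (1280, 720)]
])  # [307, 921]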
@ -38,11 +38,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin, BaseProcessingInfo, PromptReplacement)
PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
@ -80,12 +80,12 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
return feat_lengths, output_lengths return feat_lengths, output_lengths
-class Qwen2AudioProcessingMixin(ProcessingMixin):
+class Qwen2AudioProcessingInfo(BaseProcessingInfo):

-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(Qwen2AudioConfig)

-    def _get_hf_processor(
+    def get_hf_processor(
self, self,
*, *,
# Ignored in initialization # Ignored in initialization
@ -93,36 +93,37 @@ class Qwen2AudioProcessingMixin(ProcessingMixin):
) -> Qwen2AudioProcessor: ) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor) return self.ctx.get_hf_processor(Qwen2AudioProcessor)
def _get_feature_extractor( def get_feature_extractor(
self, self,
*, *,
# Ignored in initialization # Ignored in initialization
sampling_rate: Optional[int] = None, sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor: ) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor(sampling_rate=sampling_rate) hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
feature_extractor = hf_processor.feature_extractor # type: ignore feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor) assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor return feature_extractor
class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config = self._get_hf_config() hf_config = self.get_hf_config()
max_source_positions = hf_config.audio_config.max_source_positions max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1 max_output_lengths = (max_source_positions - 2) // 2 + 1
return {"audio": max_output_lengths} return {"audio": max_output_lengths}
class Qwen2AudioDummyInputsBuilder(
BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor() feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate
@ -139,14 +140,11 @@ class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
) )
-class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
-                                    BaseMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return Qwen2AudioProfilingInfo(self.ctx)
+class Qwen2AudioMultiModalProcessor(
+        BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor() feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor( def _call_hf_processor(
@ -161,7 +159,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
if audios: if audios:
mm_data["audios"] = audios mm_data["audios"] = audios
feature_extractor = self._get_feature_extractor(**mm_kwargs) feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_kwargs = dict( mm_kwargs = dict(
**mm_kwargs, **mm_kwargs,
sampling_rate=feature_extractor.sampling_rate, sampling_rate=feature_extractor.sampling_rate,
@ -194,7 +192,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_config = self._get_hf_config() hf_config = self.info.get_hf_config()
placeholder = hf_config.audio_token_index placeholder = hf_config.audio_token_index
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
@ -234,10 +232,13 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
# has already performed processing for multi-audio input when the input # has already performed processing for multi-audio input when the input
# audios are short (the corresponding placeholders may take up fewer # audios are short (the corresponding placeholders may take up fewer
# tokens than the number of audio items) # tokens than the number of audio items)
return not hasattr(self._get_hf_processor(), "audio_token") return not hasattr(self.info.get_hf_processor(), "audio_token")
-@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(
+    Qwen2AudioMultiModalProcessor,
+    info=Qwen2AudioProcessingInfo,
+    dummy_inputs=Qwen2AudioDummyInputsBuilder)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
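The `get_mm_max_tokens_per_item` override above derives the audio token budget directly from the encoder's `max_source_positions` via `(max_source_positions - 2) // 2 + 1`. A worked example; the value 1500 is the usual Whisper-style encoder length, used here purely for illustration:

# NOTE: worked example of the audio budget formula above; 1500 is illustrative.
max_source_positions = 1500
max_output_lengths = (max_source_positions - 2) // 2 + 1
print(max_output_lengths)  # 750 audio tokens per item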
@ -57,11 +57,10 @@ from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs, MultiModalFieldConfig, MultiModalKwargs,
NestedTensors, VideoItem) NestedTensors, VideoItem)
from vllm.multimodal.parse import (ImageSize, ModalityDataItems, from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataParser) MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin, BaseProcessingInfo, PromptReplacement)
PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.platforms import _Backend from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
@ -709,12 +708,12 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
return super()._parse_video_data(data) return super()._parse_video_data(data)
-class Qwen2VLProcessingMixin(ProcessingMixin):
+class Qwen2VLProcessingInfo(BaseProcessingInfo):

-    def _get_hf_config(self):
+    def get_hf_config(self):
         return self.ctx.get_hf_config(Qwen2VLConfig)

-    def _get_hf_processor(
+    def get_hf_processor(
self, self,
*, *,
min_pixels: Optional[int] = None, min_pixels: Optional[int] = None,
@ -736,18 +735,27 @@ class Qwen2VLProcessingMixin(ProcessingMixin):
return hf_processor return hf_processor
def _get_image_processor( def get_image_processor(
self, self,
*, *,
min_pixels: Optional[int] = None, min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None, max_pixels: Optional[int] = None,
): ):
hf_processor = self._get_hf_processor(min_pixels=min_pixels, hf_processor = self.get_hf_processor(min_pixels=min_pixels,
max_pixels=max_pixels) max_pixels=max_pixels)
image_processor = hf_processor.image_processor # type: ignore image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor) assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor return image_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
}
def _get_vision_info( def _get_vision_info(
self, self,
*, *,
@ -755,15 +763,17 @@ class Qwen2VLProcessingMixin(ProcessingMixin):
image_height: int, image_height: int,
num_frames: int = 1, num_frames: int = 1,
do_resize: bool = True, do_resize: bool = True,
image_processor: Optional[Qwen2VLImageProcessor],
) -> tuple[ImageSize, int]: ) -> tuple[ImageSize, int]:
hf_config = self._get_hf_config() if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size temporal_patch_size = vision_config.temporal_patch_size
image_processor = self._get_image_processor()
if do_resize: if do_resize:
resized_height, resized_width = smart_resize( resized_height, resized_width = smart_resize(
height=image_height, height=image_height,
@ -787,70 +797,65 @@ class Qwen2VLProcessingMixin(ProcessingMixin):
return preprocessed_size, num_vision_tokens return preprocessed_size, num_vision_tokens
def _get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
image_processor: Optional[Qwen2VLImageProcessor],
) -> int: ) -> int:
_, num_image_tokens = self._get_vision_info( _, num_image_tokens = self._get_vision_info(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
image_processor=image_processor,
) )
return num_image_tokens return num_image_tokens
def _get_num_video_tokens( def get_num_video_tokens(
self, self,
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
num_frames: int, num_frames: int,
image_processor: Optional[Qwen2VLImageProcessor],
) -> int: ) -> int:
_, num_video_tokens = self._get_vision_info( _, num_video_tokens = self._get_vision_info(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
num_frames=num_frames, num_frames=num_frames,
image_processor=image_processor,
) )
return num_video_tokens return num_video_tokens
def get_image_size_with_most_features(self) -> ImageSize:
class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self._get_max_image_tokens(),
"video": self._get_max_video_tokens(seq_len),
}
def _get_image_size_with_most_features(self) -> ImageSize:
max_image_size, _ = self._get_vision_info( max_image_size, _ = self._get_vision_info(
image_width=9999999, image_width=9999999,
image_height=9999999, image_height=9999999,
image_processor=None,
) )
return max_image_size return max_image_size
def _get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_image_tokens( return self.get_num_image_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
image_processor=None,
) )
def _get_max_video_frames(self, max_tokens: int) -> int: def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0 num_frames = 0
while True: while True:
next_num_frames = num_frames + 1 next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens( next_max_tokens = self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=next_num_frames, num_frames=next_num_frames,
image_processor=None,
) )
if next_max_tokens > max_tokens: if next_max_tokens > max_tokens:
@ -860,12 +865,12 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
return num_frames return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int: def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config() mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1) max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1) max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len - max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens) max_image_tokens)
@ -877,15 +882,19 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
return num_frames return num_frames
def _get_max_video_tokens(self, seq_len: int) -> int: def get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self._get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_video_tokens( return self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len), num_frames=self.get_num_frames_with_most_features(seq_len),
image_processor=None,
) )
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
@ -894,10 +903,14 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
hf_processor = self._get_hf_processor() hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token image_token: str = hf_processor.image_token
video_token: str = hf_processor.video_token video_token: str = hf_processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
mm_data = { mm_data = {
"image": "image":
@ -908,7 +921,7 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
self._get_dummy_videos( self._get_dummy_videos(
width=target_width, width=target_width,
height=target_height, height=target_height,
num_frames=self._get_dummy_num_frames(seq_len), num_frames=target_num_frames,
num_videos=num_videos, num_videos=num_videos,
) )
} }
@ -919,11 +932,8 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
) )
-class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
-                                 BaseMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return Qwen2VLProfilingInfo(self.ctx)
+class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
+                                 ):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser() return Qwen2MultiModalDataParser()
@ -934,8 +944,9 @@ class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self._get_image_processor(**hf_processor_mm_kwargs) image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered # image_token and video_token registered
@ -991,7 +1002,9 @@ class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
) )
-@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
+                                        info=Qwen2VLProcessingInfo,
+                                        dummy_inputs=Qwen2VLDummyInputsBuilder)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP): SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
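`_get_max_video_frames` above searches for the largest frame count whose token cost still fits the budget by growing the count one frame at a time. A standalone sketch of that loop; the flat `tokens_per_frame` constant is an invented stand-in for the real resolution-dependent computation:

# NOTE: standalone sketch of the incremental frame search; toy cost model.
def max_video_frames(max_tokens: int, tokens_per_frame: int = 196) -> int:
    num_frames = 0
    while True:
        next_num_frames = num_frames + 1
        if next_num_frames * tokens_per_frame > max_tokens:
            break
        num_frames = next_num_frames
    return num_frames


print(max_video_frames(4096))  # 20 frames of 196 tokens fit into 4096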
@ -24,11 +24,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin, BaseProcessingInfo, PromptReplacement)
PromptReplacement) from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
@ -59,9 +58,9 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
UltravoxAudioEmbeddingInputs] UltravoxAudioEmbeddingInputs]
-class UltravoxProcessingMixin(ProcessingMixin):
+class UltravoxProcessingInfo(BaseProcessingInfo):

-    def _get_hf_processor(
+    def get_hf_processor(
self, self,
*, *,
# Ignored in initialization # Ignored in initialization
@ -76,37 +75,38 @@ class UltravoxProcessingMixin(ProcessingMixin):
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
return hf_processor return hf_processor
def _get_feature_extractor( def get_feature_extractor(
self, self,
*, *,
# Ignored in initialization # Ignored in initialization
sampling_rate: Optional[int] = None, sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor: ) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor(sampling_rate=sampling_rate) hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
audio_processor = hf_processor.audio_processor # type: ignore audio_processor = hf_processor.audio_processor # type: ignore
feature_extractor = audio_processor.feature_extractor # type: ignore feature_extractor = audio_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor) assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor return feature_extractor
class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
feature_extractor = self._get_feature_extractor() feature_extractor = self.get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length * max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND) _AUDIO_TOKENS_PER_SECOND)
return {"audio": max_audio_tokens} return {"audio": max_audio_tokens}
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
):
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor() feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate
@ -123,14 +123,11 @@ class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo):
) )
-class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
-                                  BaseMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return UltravoxProfilingInfo(self.ctx)
+class UltravoxMultiModalProcessor(
+        BaseMultiModalProcessor[UltravoxProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor() feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor( def _call_hf_processor(
@ -141,7 +138,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
) -> BatchFeature: ) -> BatchFeature:
# Text-only input not supported in composite processor # Text-only input not supported in composite processor
if not mm_data: if not mm_data:
tokenizer = self._get_tokenizer() tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode( prompt_ids = tokenizer.encode(
prompt, prompt,
@ -160,7 +157,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
) )
feature_extractor = self._get_feature_extractor() feature_extractor = self.info.get_feature_extractor()
mm_kwargs = dict( mm_kwargs = dict(
**mm_kwargs, **mm_kwargs,
sampling_rate=feature_extractor.sampling_rate, sampling_rate=feature_extractor.sampling_rate,
@ -208,7 +205,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
placeholder = hf_processor.audio_token_replacement # type: ignore placeholder = hf_processor.audio_token_replacement # type: ignore
def get_replacement_ultravox(item_idx: int): def get_replacement_ultravox(item_idx: int):
@ -342,7 +339,10 @@ class ModifiedWhisperEncoder(WhisperEncoder):
return hidden_states return hidden_states
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor) @MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor,
info=UltravoxProcessingInfo,
dummy_inputs=UltravoxDummyInputsBuilder
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
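Taken together, the Ultravox changes illustrate the new registration contract: a model now supplies a processor class, a processing-info class, and a dummy-inputs builder instead of a single processor factory. A structural sketch for a hypothetical model (every My* name is made up for illustration; method bodies are elided):

import torch.nn as nn

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder


class MyProcessingInfo(BaseProcessingInfo):
    ...  # get_supported_mm_limits / get_mm_max_tokens_per_item


class MyDummyInputsBuilder(BaseDummyInputsBuilder[MyProcessingInfo]):
    ...  # get_dummy_processor_inputs


class MyMultiModalProcessor(BaseMultiModalProcessor[MyProcessingInfo]):
    ...  # _call_hf_processor and prompt-replacement overrides


@MULTIMODAL_REGISTRY.register_processor(MyMultiModalProcessor,
                                        info=MyProcessingInfo,
                                        dummy_inputs=MyDummyInputsBuilder)
class MyModel(nn.Module):
    ...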
View File
@ -4,12 +4,13 @@ from collections import defaultdict
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union)
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from vllm import envs import vllm.envs as envs
from vllm.inputs import DummyData, InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens) encode_tokens)
@ -20,7 +21,9 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser from .parse import MultiModalDataItems, MultiModalDataParser
from .profiling import BaseProfilingInfo
if TYPE_CHECKING:
from .profiling import BaseDummyInputsBuilder
logger = init_logger(__name__) logger = init_logger(__name__)
@ -46,8 +49,8 @@ class PromptReplacement:
if it does not depend on the input. if it does not depend on the input.
""" """
def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement": def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
return _BoundPromptReplacement( return BoundPromptReplacement(
tokenizer=tokenizer, tokenizer=tokenizer,
modality=self.modality, modality=self.modality,
_target=self.target, _target=self.target,
@ -128,7 +131,7 @@ class _BoundPromptSequence:
@dataclass @dataclass
class _BoundPromptReplacement: class BoundPromptReplacement:
tokenizer: AnyTokenizer = field(repr=False) tokenizer: AnyTokenizer = field(repr=False)
modality: str modality: str
@ -207,7 +210,7 @@ def iter_token_matches(
@dataclass(repr=False) @dataclass(repr=False)
class _PromptReplacementMatch(ABC): class _PromptReplacementMatch(ABC):
prompt_repl: _BoundPromptReplacement prompt_repl: BoundPromptReplacement
@property @property
def modality(self) -> str: def modality(self) -> str:
@ -255,7 +258,7 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
@dataclass @dataclass
class _PlaceholderInfo: class PlaceholderInfo:
modality: str modality: str
item_idx: int item_idx: int
start_idx: int start_idx: int
@ -274,7 +277,7 @@ class _PlaceholderInfo:
def find_token_matches( def find_token_matches(
prompt: list[int], prompt: list[int],
prompt_repls: Sequence[_BoundPromptReplacement], prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]: ) -> list[_PromptReplacementTokenMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`.""" """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [ return [
@ -286,7 +289,7 @@ def find_token_matches(
def find_text_matches( def find_text_matches(
prompt: str, prompt: str,
prompt_repls: Sequence[_BoundPromptReplacement], prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]: ) -> list[_PromptReplacementTextMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`.""" """Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [ return [
@ -390,9 +393,9 @@ def replace_text_matches(
def _iter_modality_placeholders( def _iter_modality_placeholders(
prompt: list[int], prompt: list[int],
modality: str, modality: str,
modality_repls: Sequence[_BoundPromptReplacement], modality_repls: Sequence[BoundPromptReplacement],
modal_item_count: int, modal_item_count: int,
) -> Iterable[_PlaceholderInfo]: ) -> Iterable[PlaceholderInfo]:
if modal_item_count == 0: if modal_item_count == 0:
return return
@ -413,7 +416,7 @@ def _iter_modality_placeholders(
continue continue
if prompt[start_idx:end_idx] == repl_tokens: if prompt[start_idx:end_idx] == repl_tokens:
yield _PlaceholderInfo( yield PlaceholderInfo(
modality=modality, modality=modality,
item_idx=item_idx, item_idx=item_idx,
start_idx=start_idx, start_idx=start_idx,
@ -434,10 +437,10 @@ def _iter_modality_placeholders(
def _iter_placeholders( def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int], prompt: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]: ) -> Iterable[PlaceholderInfo]:
""" """
For each modality, yield each set of placeholder tokens found in For each modality, yield each set of placeholder tokens found in
:code:`prompt`. :code:`prompt`.
@ -455,10 +458,10 @@ def _iter_placeholders(
def find_mm_placeholders( def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int], prompt: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]: ) -> Mapping[str, list[PlaceholderInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts) it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
@ -524,29 +527,59 @@ class ProcessingCache:
self._cache.put(cache_key, output_kwargs) self._cache.put(cache_key, output_kwargs)
class ProcessingMixin: class BaseProcessingInfo:
""" """Base class containing information to perform processing."""
Contains helper functions to perform processing.
Not to be confused with :class:`transformers.ProcessorMixin`. def __init__(self, ctx: InputProcessingContext) -> None:
""" super().__init__()
ctx: InputProcessingContext
def _get_tokenizer(self) -> AnyTokenizer: self.ctx = ctx
@property
def model_id(self) -> str:
return self.ctx.model_config.model
def get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer return self.ctx.tokenizer
def _get_hf_config(self) -> PretrainedConfig: def get_hf_config(self) -> PretrainedConfig:
return self.ctx.get_hf_config() return self.ctx.get_hf_config()
def _get_hf_processor(self, **kwargs: object) -> ProcessorMixin: def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
""" """
Subclasses can override this method to handle Subclasses can override this method to handle
specific kwargs from model config or user inputs. specific kwargs from model config or user inputs.
""" """
return self.ctx.get_hf_processor(**kwargs) return self.ctx.get_hf_processor(**kwargs)
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
"""
Return the maximum supported number of items for each modality.
class BaseMultiModalProcessor(ProcessingMixin, ABC): A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
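A minimal concrete subclass of BaseProcessingInfo (the class defined just above) might look like the following sketch, assuming a hypothetical image-only model whose vision encoder always emits a fixed number of placeholder tokens; the class name and the token count are illustrative:

from collections.abc import Mapping
from typing import Optional


class FixedImageProcessingInfo(BaseProcessingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # Allow at most one image per prompt; None would mean "unlimited".
        return {"image": 1}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        # Worst-case placeholder count per image; independent of seq_len here.
        return {"image": 576}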
_I = TypeVar("_I", bound=BaseProcessingInfo)
class BaseMultiModalProcessor(ABC, Generic[_I]):
""" """
Abstract base class to process multi-modal inputs to be used in vLLM. Abstract base class to process multi-modal inputs to be used in vLLM.
@ -554,18 +587,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
""" """
def __init__(self, def __init__(self,
ctx: InputProcessingContext, info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]",
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None: enable_sanity_checks: bool = True) -> None:
super().__init__() super().__init__()
self.ctx = ctx self.info = info
self.dummy_inputs = dummy_inputs
self.cache = cache self.cache = cache
self.enable_sanity_checks = enable_sanity_checks self.enable_sanity_checks = enable_sanity_checks
self.data_parser = self._get_data_parser() self.data_parser = self._get_data_parser()
self.profiling_info = self._get_profiling_info()
def __call__( def __call__(
self, self,
@ -585,13 +619,6 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
""" """
return MultiModalDataParser() return MultiModalDataParser()
def _get_profiling_info(self) -> BaseProfilingInfo:
"""
Get the profiling information to find the worst-case memory usage of
the model.
"""
raise NotImplementedError
def _to_mm_items( def _to_mm_items(
self, self,
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
@ -602,7 +629,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
""" """
mm_items = self.data_parser.parse_mm_data(mm_data) mm_items = self.data_parser.parse_mm_data(mm_data)
mm_limits = self.ctx.get_mm_config().limit_per_prompt mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
for modality, items in mm_items.items(): for modality, items in mm_items.items():
limit = mm_limits.get(modality, 1) limit = mm_limits.get(modality, 1)
if len(items) > limit: if len(items) > limit:
@ -646,19 +673,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _find_mm_placeholders( def _find_mm_placeholders(
self, self,
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
new_token_ids: list[int], new_token_ids: list[int],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]: ) -> Mapping[str, list[PlaceholderInfo]]:
return find_mm_placeholders(mm_prompt_repls, new_token_ids, return find_mm_placeholders(mm_prompt_repls, new_token_ids,
mm_item_counts) mm_item_counts)
def _get_hf_mm_data( def _get_hf_mm_data(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]: ) -> tuple[Mapping[str, object], Mapping[str, object]]:
processor_data = dict[str, Any]() processor_data = dict[str, object]()
passthrough_data = dict[str, Any]() passthrough_data = dict[str, object]()
for items in mm_items.values(): for items in mm_items.values():
processor_data.update(items.get_processor_data()) processor_data.update(items.get_processor_data())
@ -678,8 +705,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
Call the HF processor on the prompt text and Call the HF processor on the prompt text and
associated multi-modal data. associated multi-modal data.
""" """
return self.ctx.call_hf_processor( return self.info.ctx.call_hf_processor(
self._get_hf_processor(**mm_kwargs), self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data), dict(text=prompt, **mm_data),
mm_kwargs, mm_kwargs,
) )
@ -738,8 +765,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# Some HF processors (e.g. Qwen2-VL) expect corresponding # Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text # multi-modal tokens to be in the prompt text
dummy_inputs = self.profiling_info.get_dummy_processor_inputs( dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
self.ctx.model_config.max_model_len, self.info.ctx.model_config.max_model_len,
mm_missing_counts, mm_missing_counts,
) )
@ -762,7 +789,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
caching the results and reusing cached results. caching the results and reusing cached results.
""" """
cache = self.cache cache = self.cache
model_id = self.ctx.model_config.model model_id = self.info.model_id
_, passthrough_data = self._get_hf_mm_data(mm_data_items) _, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data: if cache is None or passthrough_data:
@ -838,8 +865,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _bind_and_group_repls( def _bind_and_group_repls(
self, self,
prompt_repls: list[PromptReplacement], prompt_repls: list[PromptReplacement],
) -> dict[str, list[_BoundPromptReplacement]]: ) -> dict[str, list[BoundPromptReplacement]]:
tokenizer = self._get_tokenizer() tokenizer = self.info.get_tokenizer()
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
@ -859,10 +886,10 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _apply_prompt_replacements( def _apply_prompt_replacements(
self, self,
token_ids: list[int], token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
tokenizer = self._get_tokenizer() tokenizer = self.info.get_tokenizer()
mm_token_matches = { mm_token_matches = {
modality: find_token_matches(token_ids, prompt_repls) modality: find_token_matches(token_ids, prompt_repls)
@ -950,7 +977,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _validate_mm_placeholders( def _validate_mm_placeholders(
self, self,
mm_placeholders: Mapping[str, list[_PlaceholderInfo]], mm_placeholders: Mapping[str, list[PlaceholderInfo]],
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
*, *,
allow_missing: bool = False, allow_missing: bool = False,
@ -1001,7 +1028,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# instead of rehashing. # instead of rehashing.
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
model_id = self.ctx.model_config.model model_id = self.info.model_id
mm_hashes = { mm_hashes = {
modality: [ modality: [
MultiModalHasher.hash_kwargs(model_id=model_id, MultiModalHasher.hash_kwargs(model_id=model_id,
@ -1046,7 +1073,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
allow_missing=True, allow_missing=True,
) )
mm_missing_repls = dict[str, list[_BoundPromptReplacement]]() mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
for modality, missing_repl_count in mm_missing_repl_counts.items(): for modality, missing_repl_count in mm_missing_repl_counts.items():
if missing_repl_count == 0: if missing_repl_count == 0:
mm_missing_repls[modality] = [] mm_missing_repls[modality] = []
@ -1059,7 +1086,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# If HF processor already inserts placeholder tokens, # If HF processor already inserts placeholder tokens,
# there is no need for us to insert them # there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()): if all(len(repls) == 0 for repls in mm_missing_repls.items()):
tokenizer = self._get_tokenizer() tokenizer = self.info.get_tokenizer()
prompt_text = decode_tokens(tokenizer, prompt_ids) prompt_text = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders mm_placeholders = hf_mm_placeholders
else: else:
@ -1090,79 +1117,3 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
mm_hashes=mm_hashes, mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges, mm_placeholders=mm_placeholder_ranges,
) )
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
profiling = self.profiling_info
processor_inputs = profiling.get_dummy_processor_inputs(
seq_len, mm_counts)
return self.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
profiling = self.profiling_info
mm_counts = profiling.get_mm_limits()
mm_max_tokens_per_item = profiling.get_mm_max_tokens_per_item(seq_len)
if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits`"
f"({set(mm_counts.keys())}) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
raise AssertionError(
f"The processed dummy data has a total of "
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
total_len = len(prompt_token_ids)
# V0 does not support chunked prefill.
if total_len > seq_len and not envs.VLLM_USE_V1:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None,
multi_modal_placeholders=None,
)
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=placeholders_by_modality,
)
View File
@ -1,16 +1,18 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Generic, TypeVar
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from PIL import Image from PIL import Image
from vllm.inputs import InputProcessingContext import vllm.envs as envs
from vllm.inputs import DummyData
from vllm.logger import init_logger from vllm.logger import init_logger
from .inputs import MultiModalDataDict from .inputs import MultiModalDataDict, MultiModalInputsV2
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
logger = init_logger(__name__) logger = init_logger(__name__)
@ -23,39 +25,19 @@ class ProcessorInputs:
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
class BaseProfilingInfo(ABC): _I = TypeVar("_I", bound=BaseProcessingInfo)
class BaseDummyInputsBuilder(ABC, Generic[_I]):
""" """
Abstract base class that provides the information necessary to profile Abstract base class that constructs the dummy data to profile
multi-modal models. multi-modal models.
""" """
def __init__(self, ctx: InputProcessingContext) -> None: def __init__(self, info: _I) -> None:
super().__init__() super().__init__()
self.ctx = ctx self.info = info
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
@ -64,8 +46,8 @@ class BaseProfilingInfo(ABC):
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
Build the multi-modal portion of the input which, after processing, Build the input which, after processing, results in
results in `mm_max_tokens` in :meth:`get_mm_max_tokens_per_item`. `self.info.get_mm_max_tokens_per_item()` placeholder tokens.
""" """
raise NotImplementedError raise NotImplementedError
@ -99,11 +81,33 @@ class BaseProfilingInfo(ABC):
video = np.zeros((num_frames, width, height, 3)) video = np.zeros((num_frames, width, height, 3))
return [video] * num_videos return [video] * num_videos
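Continuing the hypothetical FixedImageProcessingInfo sketch from the processing module above, a concrete builder only has to return ProcessorInputs with a prompt and matching dummy media; the placeholder string and image resolution below are assumptions:

class FixedImageDummyInputsBuilder(
        BaseDummyInputsBuilder[FixedImageProcessingInfo]):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)
        # Worst-case-sized dummy images; 336x336 is an assumed input resolution.
        dummy_images = [Image.new("RGB", (336, 336))] * num_images
        return ProcessorInputs(
            prompt_text="<image>" * num_images,
            mm_data={"image": dummy_images},
        )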
def get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.ctx.get_mm_config() class MultiModalProfiler(Generic[_I]):
"""
Contains code for running memory profiling for multi-modal models.
"""
def __init__(
self,
processor: BaseMultiModalProcessor[_I],
) -> None:
super().__init__()
self.processor = processor
@property
def processing_info(self) -> BaseProcessingInfo:
return self.processor.info
@property
def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
return self.processor.dummy_inputs
def _get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.processing_info.ctx.get_mm_config()
mm_limit_per_prompt = mm_config.limit_per_prompt mm_limit_per_prompt = mm_config.limit_per_prompt
supported_mm_limits = self.get_supported_mm_limits() supported_mm_limits = self.processing_info.get_supported_mm_limits()
mm_limits = { mm_limits = {
modality: mm_limit_per_prompt.get(modality, 1) modality: mm_limit_per_prompt.get(modality, 1)
@ -119,3 +123,81 @@ class BaseProfilingInfo(ABC):
f"at most {supported_limit} {modality} items.") f"at most {supported_limit} {modality} items.")
return mm_limits return mm_limits
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
factory = self.dummy_inputs
processor_inputs = factory.get_dummy_processor_inputs(
seq_len, mm_counts)
return self.processor.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
mm_counts = self._get_mm_limits()
info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)
if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits`"
f"({set(mm_counts.keys())}) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
raise AssertionError(
f"The processed dummy data has a total of "
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
total_len = len(prompt_token_ids)
# V0 does not support chunked prefill.
if total_len > seq_len and not envs.VLLM_USE_V1:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None,
multi_modal_placeholders=None,
)
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=placeholders_by_modality,
)
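Usage is deliberately simple: the profiler wraps an already-constructed processor and drives it with dummy data. A brief sketch, assuming processor is any BaseMultiModalProcessor instance obtained from the registry:

profiler = MultiModalProfiler(processor)

# Worst-case dummy request, sized to the model's context window.
seq_len = processor.info.ctx.model_config.max_model_len
dummy_data = profiler.get_dummy_data(seq_len)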
View File
@ -1,7 +1,8 @@
import functools import functools
from collections import UserDict from collections import UserDict
from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol, from dataclasses import dataclass
Sequence, Type, TypeVar) from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional,
Protocol, Sequence, Type, TypeVar)
import torch.nn as nn import torch.nn as nn
@ -14,7 +15,9 @@ from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import BaseMultiModalProcessor, ProcessingCache from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache)
from .profiling import BaseDummyInputsBuilder
from .utils import cached_get_tokenizer from .utils import cached_get_tokenizer
from .video import VideoPlugin from .video import VideoPlugin
@ -27,20 +30,59 @@ logger = init_logger(__name__)
MM_CACHE_SIZE = 256 MM_CACHE_SIZE = 256
N = TypeVar("N", bound=Type[nn.Module]) N = TypeVar("N", bound=Type[nn.Module])
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
class MultiModalProcessorFactory(Protocol): class ProcessingInfoFactory(Protocol[_I_co]):
"""Constructs a :class:`MultiModalProcessor` instance from the context.""" """Constructs a :class:`MultiModalProcessor` instance from the context."""
def __call__( def __call__(
self, self,
ctx: InputProcessingContext, ctx: InputProcessingContext,
) -> _I_co:
...
class DummyInputsBuilderFactory(Protocol[_I]):
"""
Constructs a :class:`BaseDummyInputsBuilder` instance from the context.
"""
def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
...
class MultiModalProcessorFactory(Protocol[_I]):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def __call__(
self,
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor: ) -> BaseMultiModalProcessor[_I]:
... ...
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
info: ProcessingInfoFactory[_I]
processor: MultiModalProcessorFactory[_I]
dummy_inputs: DummyInputsBuilderFactory[_I]
def build_processor(
self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
):
info = self.info(ctx)
dummy_inputs_builder = self.dummy_inputs(info)
return self.processor(info, dummy_inputs_builder, cache=cache)
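build_processor fixes the wiring order: the info object is built from the context, the dummy-inputs builder from the info, and the processor from both. Done by hand with the hypothetical My* classes sketched earlier (model_config and tokenizer are assumed to be already available), the equivalent would be roughly:

ctx = InputProcessingContext(model_config, tokenizer)

info = MyProcessingInfo(ctx)
dummy_inputs = MyDummyInputsBuilder(info)
processor = MyMultiModalProcessor(info, dummy_inputs, cache=None)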
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]): class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
""" """
Wraps `_limits_by_model` for a more informative error message Wraps `_limits_by_model` for a more informative error message
@ -71,7 +113,7 @@ class MultiModalRegistry:
self._plugins = {p.get_data_key(): p for p in plugins} self._plugins = {p.get_data_key(): p for p in plugins}
self._processor_factories = ClassRegistry[nn.Module, self._processor_factories = ClassRegistry[nn.Module,
MultiModalProcessorFactory]() _ProcessorFactories]()
# This is used for non-multimodal models # This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
@ -224,7 +266,7 @@ class MultiModalRegistry:
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer) processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
return processor.profiling_info.get_mm_max_tokens_per_item(seq_len) return processor.info.get_mm_max_tokens_per_item(seq_len)
return { return {
key: plugin.get_max_multimodal_tokens(model_config) key: plugin.get_max_multimodal_tokens(model_config)
@ -315,7 +357,10 @@ class MultiModalRegistry:
def register_processor( def register_processor(
self, self,
factory: MultiModalProcessorFactory, processor: MultiModalProcessorFactory[_I],
*,
info: ProcessingInfoFactory[_I],
dummy_inputs: DummyInputsBuilderFactory[_I],
): ):
""" """
Register a multi-modal processor to a model class. The processor Register a multi-modal processor to a model class. The processor
@ -336,7 +381,11 @@ class MultiModalRegistry:
"registered to %s. It is overwritten by the new one.", "registered to %s. It is overwritten by the new one.",
model_cls, self) model_cls, self)
self._processor_factories[model_cls] = factory self._processor_factories[model_cls] = _ProcessorFactories(
info=info,
dummy_inputs=dummy_inputs,
processor=processor,
)
return model_cls return model_cls
@ -359,15 +408,15 @@ class MultiModalRegistry:
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
) -> BaseMultiModalProcessor: ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
""" """
Create a multi-modal processor for a specific model and tokenizer. Create a multi-modal processor for a specific model and tokenizer.
""" """
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
processor_factory = self._processor_factories[model_cls] factories = self._processor_factories[model_cls]
ctx = InputProcessingContext(model_config, tokenizer) ctx = InputProcessingContext(model_config, tokenizer)
cache = (None if model_config.disable_mm_preprocessor_cache else cache = (None if model_config.disable_mm_preprocessor_cache else
self._processing_cache) self._processing_cache)
return processor_factory(ctx, cache=cache) return factories.build_processor(ctx, cache=cache)