[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-01-08 18:59:58 +08:00 committed by GitHub
parent f12141170a
commit 2a0596bc48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 833 additions and 760 deletions

View File

@ -4,24 +4,17 @@ from functools import partial
import pytest
from PIL import Image
from pqdm.threads import pqdm
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import cached_get_tokenizer
from ....utils import build_model_context
# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_next():
from vllm.model_executor.models.llava_next import (
LlavaNextMultiModalProcessor)
return LlavaNextMultiModalProcessor
def _validate_image_prompt_replacements_one(
processor,
processor: BaseMultiModalProcessor,
num_imgs: int,
failed_size_excs: list[tuple[ImageSize, Exception]],
image_size: ImageSize,
@ -78,20 +71,17 @@ def _test_image_prompt_replacements(
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements_regression(
processor_for_llava_next,
model_id: str,
num_imgs: int,
):
def test_processor_prompt_replacements_regression(model_id, num_imgs):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_next(ctx)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
(488, 183), (2560, 1669)]
@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression(
"Comment this out to run it manually.")
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("num_imgs", [1])
def test_processor_prompt_replacements_all(
processor_for_llava_next,
model_id: str,
num_imgs: int,
):
def test_processor_prompt_replacements_all(model_id, num_imgs):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_next(ctx)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
)
seen_aspect_ratios = set[float]()
image_sizes = list[ImageSize]()

View File

@ -4,24 +4,17 @@ from functools import partial
import pytest
from PIL import Image
from pqdm.threads import pqdm
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import cached_get_tokenizer
from ....utils import build_model_context
# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_onevision():
from vllm.model_executor.models.llava_onevision import (
LlavaOnevisionMultiModalProcessor)
return LlavaOnevisionMultiModalProcessor
def _validate_image_prompt_replacements_one(
processor,
processor: BaseMultiModalProcessor,
num_imgs: int,
failed_size_excs: list[tuple[ImageSize, Exception]],
image_size: ImageSize,
@ -77,20 +70,17 @@ def _test_image_prompt_replacements(
@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements_regression(
processor_for_llava_onevision,
model_id: str,
num_imgs: int,
):
def test_processor_prompt_replacements_regression(model_id, num_imgs):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_onevision(ctx)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
(488, 183), (2560, 1669)]
@ -111,20 +101,17 @@ def test_processor_prompt_replacements_regression(
@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("num_imgs", [1])
def test_processor_prompt_replacements_all(
processor_for_llava_onevision,
model_id: str,
num_imgs: int,
):
def test_processor_prompt_replacements_all(model_id, num_imgs):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
processor = processor_for_llava_onevision(ctx)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
)
seen_aspect_ratios = set[float]()
image_sizes = list[ImageSize]()

View File

@ -1,21 +1,13 @@
"""Tests for phi3v's multimodal preprocessing kwargs."""
import pytest
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from .....conftest import _ImageAssets
from ....utils import build_model_context
# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_phi3v():
from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
return Phi3VMultiModalProcessor
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
# yapf: disable
@pytest.mark.parametrize(
@ -29,7 +21,6 @@ def processor_for_phi3v():
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
processor_for_phi3v,
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, int],
@ -37,21 +28,26 @@ def test_processor_override(
num_imgs: int,
):
"""Ensure input_processor_for_phi3v handles num_crops properly."""
# Avoid initializing CUDA early
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
trust_remote_code=True,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
# Build the image str / prompt based on the number of images we pass
img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
processor = processor_for_phi3v(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
# Ensure we have the right number of placeholders per num_crops size

View File

@ -1,19 +1,12 @@
import pytest
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from .....conftest import _ImageAssets
from ....utils import build_model_context
# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
return Qwen2VLMultiModalProcessor
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
# yapf: disable
@pytest.mark.parametrize(
@ -24,7 +17,6 @@ def processor_for_qwen2_vl():
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
processor_for_qwen2_vl,
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, object],
@ -39,18 +31,20 @@ def test_processor_override(
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
# Build the image str / prompt based on the number of images we pass
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
processor = processor_for_qwen2_vl(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
# Ensure we have the right number of placeholders per num_crops size
hf_processor = processor._get_hf_processor(**mm_processor_kwargs)
hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
image_token_id = tokenizer.convert_tokens_to_ids(hf_processor.image_token)
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values"].shape

View File

@ -10,12 +10,17 @@ from PIL import Image
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
_PlaceholderInfo, find_mm_placeholders,
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache,
PromptReplacement,
find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches,
replace_text_matches,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby
@ -431,7 +436,7 @@ def test_find_replace_tokens(
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
{
"pattern_1": [
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_1",
item_idx=0,
start_idx=6,
@ -445,13 +450,13 @@ def test_find_replace_tokens(
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
{
"pattern_1": [
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_1",
item_idx=0,
start_idx=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_1",
item_idx=1,
start_idx=5,
@ -459,7 +464,7 @@ def test_find_replace_tokens(
),
],
"pattern_3": [
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_3",
item_idx=0,
start_idx=7,
@ -472,13 +477,13 @@ def test_find_replace_tokens(
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
{
"pattern_1": [
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_1",
item_idx=0,
start_idx=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_1",
item_idx=1,
start_idx=3,
@ -486,7 +491,7 @@ def test_find_replace_tokens(
),
],
"pattern_3": [
_PlaceholderInfo(
PlaceholderInfo(
modality="pattern_3",
item_idx=0,
start_idx=6,
@ -577,19 +582,15 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
processor = MULTIMODAL_REGISTRY.create_processor(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
processor = processor_factory(ctx, cache=None)
profiler = processor.profiling_info
profiler = MultiModalProfiler(processor)
mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
profiler.get_supported_mm_limits = mock_supported_mm_limits
processor.info.get_supported_mm_limits = mock_supported_mm_limits
if is_valid:
exc_ctx = nullcontext()
@ -597,7 +598,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
exc_ctx = pytest.raises(ValueError, match="this model only supports")
with exc_ctx:
profiler.get_mm_limits()
profiler.get_dummy_data(model_config.max_model_len)
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@ -620,16 +621,12 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
processor = MULTIMODAL_REGISTRY.create_processor(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
processor = processor_factory(ctx, cache=None)
rng = np.random.RandomState(0)
image = _rand_img(rng, min_wh=128, max_wh=256)
if num_images == 0:
@ -681,9 +678,9 @@ def _test_processing_cache_correctness(
hf_overrides=hf_overrides,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
@ -691,8 +688,9 @@ def _test_processing_cache_correctness(
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30)
baseline_processor = processor_factory(ctx, cache=None)
cached_processor = processor_factory(ctx, cache=cache)
baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache)
dummy_inputs = baseline_processor.dummy_inputs
rng = np.random.RandomState(0)
@ -724,7 +722,7 @@ def _test_processing_cache_correctness(
}
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
prompt = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt_text

View File

@ -2,13 +2,17 @@ from typing import Optional
import torch
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
LlavaMultiModalProcessor)
from vllm.model_executor.models.llava import (LlavaDummyInputsBuilder,
LlavaForConditionalGeneration,
LlavaMultiModalProcessor,
LlavaProcessingInfo)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor,
info=LlavaProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class MyLlava(LlavaForConditionalGeneration):
def compute_logits(

View File

@ -7,7 +7,7 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputsV2
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_info_once, print_warning_once

View File

@ -323,6 +323,7 @@ class InputRegistry:
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.multimodal.utils import cached_get_tokenizer
if mm_registry.has_processor(model_config):
@ -331,7 +332,8 @@ class InputRegistry:
trust_remote_code=model_config.trust_remote_code,
)
processor = mm_registry.create_processor(model_config, tokenizer)
dummy_data = processor.get_dummy_data(seq_len)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_dummy_data(seq_len)
else:
model_cls, _ = get_model_architecture(model_config)
if is_encoder_data:

View File

@ -23,10 +23,10 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
AriaVisionConfig)
@ -445,33 +445,33 @@ def build_mm_projector(config: PretrainedConfig):
)
class AriaProcessingMixin(ProcessingMixin):
class AriaProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config()
def _get_vision_config(self) -> AriaVisionConfig:
return self._get_hf_config().vision_config
def _get_num_image_tokens(self) -> int:
hf_config = self._get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())
class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
def get_vision_config(self) -> AriaVisionConfig:
return self.get_hf_config().vision_config
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values())
class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
vision_config = self._get_vision_config()
vision_config = self.info.get_vision_config()
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
@ -483,7 +483,7 @@ class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
num_images=num_images)
}
hf_processor = self._get_hf_processor()
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token # type: ignore
return ProcessorInputs(
@ -492,10 +492,7 @@ class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):
)
class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return AriaProfilingInfo(self.ctx)
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
def _get_mm_fields_config(
self,
@ -513,10 +510,10 @@ class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
num_image_tokens = self._get_num_image_tokens()
num_image_tokens = self.info.get_num_image_tokens()
return [
PromptReplacement(
@ -527,7 +524,9 @@ class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
]
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor,
info=AriaProcessingInfo,
dummy_inputs=AriaDummyInputsBuilder)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
"""
Aria model for conditional generation tasks.

View File

@ -17,10 +17,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .blip import BlipVisionModel
@ -397,30 +397,30 @@ class Blip2QFormerModel(nn.Module):
return sequence_output
class Blip2ProcessingMixin(ProcessingMixin):
class Blip2ProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(Blip2Config)
def _get_num_image_tokens(self) -> int:
hf_config = self._get_hf_config()
return hf_config.num_query_tokens
class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
hf_config = self.get_hf_config()
return hf_config.num_query_tokens
class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
max_image_size = vision_config.image_size
@ -439,10 +439,7 @@ class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):
)
class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Blip2ProfilingInfo(self.ctx)
class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
def _get_mm_fields_config(
self,
@ -460,7 +457,7 @@ class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
num_image_tokens = self._get_num_image_tokens()
num_image_tokens = self.info.get_num_image_tokens()
return [
PromptReplacement(
@ -491,7 +488,9 @@ class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
return result
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
info=Blip2ProcessingInfo,
dummy_inputs=Blip2DummyInputsBuilder)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -30,10 +30,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once
@ -49,33 +49,34 @@ class ChameleonImagePixelInputs(TypedDict):
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class ChameleonProcessingMixin(ProcessingMixin):
class ChameleonProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(ChameleonConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(ChameleonProcessor)
def _get_num_image_tokens(self) -> int:
processor = self._get_hf_processor()
return processor.image_seq_length
class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()}
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
processor = self.get_hf_processor()
return processor.image_seq_length
class ChameleonDummyInputsBuilder(
BaseDummyInputsBuilder[ChameleonProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
config = self._get_hf_config()
config = self.info.get_hf_config()
width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)
@ -93,11 +94,8 @@ class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
)
class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return ChameleonProfilingInfo(self.ctx)
class ChameleonMultiModalProcessor(
BaseMultiModalProcessor[ChameleonProcessingInfo]):
def _get_mm_fields_config(
self,
@ -112,7 +110,7 @@ class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
processor = self._get_hf_processor(**hf_processor_mm_kwargs)
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
return [
PromptReplacement(
@ -120,7 +118,7 @@ class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
target="<image>",
replacement="".join([
processor.image_start_token,
processor.image_token * self._get_num_image_tokens(),
processor.image_token * self.info.get_num_image_tokens(),
processor.image_end_token,
]),
)
@ -916,7 +914,10 @@ class ChameleonModel(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(ChameleonMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(
ChameleonMultiModalProcessor,
info=ChameleonProcessingInfo,
dummy_inputs=ChameleonDummyInputsBuilder)
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

View File

@ -33,11 +33,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
@ -64,24 +64,38 @@ class FuyuImagePatchInputs(TypedDict):
"""
class FuyuProcessingMixin(ProcessingMixin):
class FuyuProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(FuyuConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(FuyuProcessor)
def _get_image_processor(self) -> FuyuImageProcessor:
return self._get_hf_processor().image_processor
def get_image_processor(self) -> FuyuImageProcessor:
return self.get_hf_processor().image_processor
def _get_image_feature_grid_size(
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_ncols, max_nrows = self.get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
max_image_tokens = (max_ncols + 1) * max_nrows
return {"image": max_image_tokens}
def get_image_feature_grid_size(
self,
*,
image_width: int,
image_height: int,
) -> tuple[int, int]:
image_processor = self._get_image_processor()
image_processor = self.get_image_processor()
target_width = image_processor.size["width"]
target_height = image_processor.size["height"]
@ -97,34 +111,21 @@ class FuyuProcessingMixin(ProcessingMixin):
nrows = math.ceil(image_height / 30)
return ncols, nrows
class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_ncols, max_nrows = self._get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
max_image_tokens = (max_ncols + 1) * max_nrows
return {"image": max_image_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
image_processor = self._get_image_processor()
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
return ImageSize(width=image_processor.size["width"],
height=image_processor.size["height"])
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
@ -140,10 +141,7 @@ class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):
)
class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return FuyuProfilingInfo(self.ctx)
class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
def _call_hf_processor(
self,
@ -156,7 +154,7 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
# Avoid warning from HF logger for text-only input
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
# Tokenizer won't add boa_token_id by default, we add it manually.
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
@ -196,10 +194,10 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
eot_token_id = tokenizer.bos_token_id
assert isinstance(eot_token_id, int)
@ -207,7 +205,7 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = self._get_image_feature_grid_size(
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image_size.width,
image_height=image_size.height,
)
@ -244,7 +242,9 @@ class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
return result
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
info=FuyuProcessingInfo,
dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
import torch
import torch.nn as nn
@ -25,11 +25,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize)
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingCache,
ProcessingMixin, PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, ProcessingCache,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel
@ -105,34 +105,23 @@ class LlavaLikeProcessor(Protocol):
image_token: Final[str]
class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
class BaseLlavaProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self) -> LlavaLikeConfig:
def get_hf_config(self) -> LlavaLikeConfig:
return self.ctx.get_hf_config(LlavaConfig)
def _get_vision_encoder_info(self):
return get_vision_encoder_info(self._get_hf_config())
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
@abstractmethod
def _get_hf_processor(self) -> LlavaLikeProcessor:
def get_hf_processor(self) -> LlavaLikeProcessor:
raise NotImplementedError
def _get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
vision_encoder_info = self._get_vision_encoder_info()
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _apply_feature_select_strategy(
self,
@ -147,28 +136,42 @@ class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
msg = f"Unexpected feature select strategy: {strategy!r}"
raise NotImplementedError(msg)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_max_image_tokens(self) -> int:
target_width, target_height = self._get_image_size_with_most_features()
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_image_tokens(
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
self,
seq_len: int,
@ -176,9 +179,10 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
image_token = processor.image_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
@ -193,23 +197,13 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
)
class LlavaProcessingMixin(BaseLlavaProcessingMixin):
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaProcessor)
class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo):
pass
class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
BaseMultiModalProcessor):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_profiling_info(self) -> BaseProfilingInfo:
raise NotImplementedError
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
# Copied from BaseMultiModalProcessor
@abstractmethod
@ -226,7 +220,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
def get_replacement(item_idx: int):
@ -237,7 +231,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self._get_num_image_tokens(
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
@ -253,10 +247,8 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
]
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaProfilingInfo(self.ctx)
class LlavaMultiModalProcessor(
BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
def _get_mm_fields_config(
self,
@ -269,21 +261,14 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
)
class PixtralHFProcessingMixin(BaseLlavaProcessingMixin):
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(PixtralProcessor)
class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo):
pass
class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return PixtralHFProfilingInfo(self.ctx)
class PixtralHFMultiModalProcessor(
BaseMultiModalProcessor[PixtralHFProcessingInfo]):
def _call_hf_processor(
self,
@ -328,10 +313,10 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
@ -363,26 +348,40 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
]
def _build_llava_or_pixtral_hf_info(
ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
hf_config = ctx.get_hf_config(LlavaConfig)
if isinstance(hf_config.vision_config, PixtralVisionConfig):
return PixtralHFProcessingInfo(ctx)
return LlavaProcessingInfo(ctx)
def _build_llava_or_pixtral_hf_processor(
ctx: InputProcessingContext,
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
hf_config = ctx.get_hf_config(LlavaConfig)
if isinstance(hf_config.vision_config, PixtralVisionConfig):
if isinstance(info, PixtralHFProcessingInfo):
return PixtralHFMultiModalProcessor(
ctx,
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
return LlavaMultiModalProcessor(
ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
if isinstance(info, LlavaProcessingInfo):
return LlavaMultiModalProcessor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
raise NotImplementedError(type(info))
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
@ -460,7 +459,9 @@ def init_vision_tower_for_llava(
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
info=_build_llava_or_pixtral_hf_info,
dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
@ -727,11 +728,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
# Assume that it doesn't depend on the image size
num_image_tokens = self._get_num_image_tokens(
num_image_tokens = self.info.get_num_image_tokens(
image_width=-1,
image_height=-1,
)
@ -796,6 +797,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
info=LlavaProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass

View File

@ -1,6 +1,7 @@
from abc import abstractmethod
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
import torch
import torch.nn as nn
@ -16,13 +17,12 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.profiling import BaseProfilingInfo
from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin,
BaseLlavaProfilingInfo, LlavaLikeConfig,
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder, LlavaLikeConfig,
LlavaMultiModalProjector, init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
@ -65,23 +65,23 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
image_grid_pinpoints: Final[list[list[int]]]
class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_config(self) -> LlavaNextLikeConfig:
def get_hf_config(self) -> LlavaNextLikeConfig:
return self.ctx.get_hf_config(LlavaNextConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaNextProcessor)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
def _get_num_image_tokens(
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
vision_encoder_info = self._get_vision_encoder_info()
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
base_feature_size = self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
@ -140,16 +140,13 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
return (unpadded_features, newline_features)
class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
def _get_image_size_with_most_features(self) -> ImageSize:
hf_config = self._get_hf_config()
def get_image_size_with_most_features(self) -> ImageSize:
hf_config = self.get_hf_config()
largest_feature_size, largest_feature_pinpoint = 0, None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = self._get_num_image_tokens(image_width=width,
image_height=height)
feat_size = self.get_num_image_tokens(image_width=width,
image_height=height)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
@ -161,11 +158,23 @@ class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
return largest_feature_pinpoint
class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
BaseLlavaMultiModalProcessor):
_I = TypeVar("_I", bound=LlavaNextProcessingInfo)
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaNextProfilingInfo(self.ctx)
class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError
class LlavaNextMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
def _get_mm_fields_config(
self,
@ -179,7 +188,9 @@ class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
)
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
info=LlavaNextProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

View File

@ -17,12 +17,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageSize, VideoEmbeddingItems,
VideoProcessorItems)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -47,33 +46,52 @@ class LlavaNextVideoPixelInputs(TypedDict):
"""
class LlavaNextVideoProcessingMixin(ProcessingMixin):
class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(LlavaNextVideoConfig)
def _get_vision_encoder_info(self):
return get_vision_encoder_info(self._get_hf_config())
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaNextVideoProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len),
)
return {"video": max_video_tokens}
def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_num_frame_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
hf_config = self.get_hf_config()
spatial_pool_stride = hf_config.spatial_pool_stride
vision_encoder_info = self._get_vision_encoder_info()
vision_encoder_info = self.get_vision_encoder_info()
patch_grid_length = vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
return pooled_grid_length * pooled_grid_length
def _get_num_video_tokens(
def get_num_video_tokens(
self,
*,
image_width: int,
@ -87,37 +105,14 @@ class LlavaNextVideoProcessingMixin(ProcessingMixin):
return num_frame_tokens * num_frames
class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_video_tokens = self._get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
)
return {"video": max_video_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
while True:
next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens(
next_max_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
@ -130,7 +125,7 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_videos = mm_config.limit_per_prompt.get("video", 1)
@ -138,6 +133,10 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
return max(max_total_frames // max(max_videos, 1), 1)
class LlavaNextVideoDummyInputsBuilder(
BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
@ -145,16 +144,20 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
) -> ProcessorInputs:
num_videos = mm_counts.get("video", 0)
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
video_token = processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
mm_data = {
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=target_num_frames,
num_videos=num_videos,
)
}
@ -165,11 +168,8 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
)
class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaNextVideoProfilingInfo(self.ctx)
class LlavaNextVideoMultiModalProcessor(
BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):
def _get_mm_fields_config(
self,
@ -184,7 +184,7 @@ class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index
def get_replacement(item_idx: int):
@ -195,7 +195,7 @@ class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
num_video_tokens = videos.get_feature_size(item_idx)
else:
image_size = videos.get_frame_size(item_idx)
num_video_tokens = self._get_num_video_tokens(
num_video_tokens = self.info.get_num_video_tokens(
image_width=image_size.width,
image_height=image_size.height,
num_frames=videos.get_num_frames(item_idx),
@ -269,7 +269,11 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(
LlavaNextVideoMultiModalProcessor,
info=LlavaNextVideoProcessingInfo,
dummy_inputs=LlavaNextVideoDummyInputsBuilder,
)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

View File

@ -17,19 +17,20 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava
from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor,
LlavaNextProcessingMixin)
from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
LlavaNextProcessingInfo)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@ -89,14 +90,23 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
video_token_index: Final[int]
class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def _get_hf_config(self) -> LlavaOnevisionLikeConfig:
def get_hf_config(self) -> LlavaOnevisionLikeConfig:
return self.ctx.get_hf_config(LlavaOnevisionConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaOnevisionProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
}
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# with additional logic afterwards taken from LlavaOnevisionProcessor
def _get_num_unpadded_features(
@ -141,16 +151,16 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
hf_config = self.get_hf_config()
spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)
vision_encoder_info = self._get_vision_encoder_info()
vision_encoder_info = self.get_vision_encoder_info()
patch_grid_length = vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
return pooled_grid_length * pooled_grid_length
def _get_num_video_tokens(
def get_num_video_tokens(
self,
*,
image_width: int,
@ -164,43 +174,14 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
return num_frame_tokens * num_frames + 1 # Newline token
class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
BaseLlavaProfilingInfo):
def _get_image_size_with_most_features(self) -> ImageSize:
hf_config = self._get_hf_config()
largest_feature_size, largest_feature_pinpoint = 0, None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = self._get_num_image_tokens(image_width=width,
image_height=height)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self._get_max_image_tokens(),
"video": self._get_max_video_tokens(seq_len),
}
def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
while True:
next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens(
next_max_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
@ -213,12 +194,12 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
@ -226,15 +207,19 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
return max(max_frames_per_video, 1)
def _get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
def get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_video_tokens(
return self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=self.get_num_frames_with_most_features(seq_len),
)
class LlavaOnevisionDummyInputsBuilder(
LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
@ -243,10 +228,14 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
image_token = processor.image_token
video_token = processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
mm_data = {
"image":
@ -257,7 +246,7 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=target_num_frames,
num_videos=num_videos,
)
}
@ -268,11 +257,8 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
)
class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
LlavaNextMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaOnevisionProfilingInfo(self.ctx)
class LlavaOnevisionMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):
def _get_mm_fields_config(
self,
@ -303,7 +289,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
mm_kwargs=mm_kwargs,
)
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
video_token = processor.video_token
# LLaVA-OneVision processor doesn't support multiple videos
@ -345,7 +331,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
out_mm_kwargs=out_mm_kwargs,
)
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index
def get_video_replacement(item_idx: int):
@ -356,7 +342,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
num_video_tokens = videos.get_feature_size(item_idx)
else:
image_size = videos.get_frame_size(item_idx)
num_video_tokens = self._get_num_video_tokens(
num_video_tokens = self.info.get_num_video_tokens(
image_width=image_size.width,
image_height=image_size.height,
num_frames=videos.get_num_frames(item_idx),
@ -393,7 +379,10 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(
LlavaOnevisionMultiModalProcessor,
info=LlavaOnevisionProcessingInfo,
dummy_inputs=LlavaOnevisionDummyInputsBuilder)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

View File

@ -34,13 +34,12 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize)
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement,
_BoundPromptReplacement,
_PlaceholderInfo)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo,
BoundPromptReplacement,
PlaceholderInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -302,9 +301,9 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
class Phi3VProcessingMixin(ProcessingMixin):
class Phi3VProcessingInfo(BaseProcessingInfo):
def _get_hf_processor(
def get_hf_processor(
self,
*,
num_crops: Optional[int] = None,
@ -314,39 +313,42 @@ class Phi3VProcessingMixin(ProcessingMixin):
return self.ctx.get_hf_processor()
def _get_num_image_tokens(
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
return {"image": max_image_tokens}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[ProcessorMixin],
) -> int:
processor = self._get_hf_processor()
if processor is None:
processor = self.get_hf_processor()
return processor.calc_num_image_tokens_from_image_size( # type: ignore
width=image_width,
height=image_height,
)
class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_image_tokens = self._get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
return {"image": max_image_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
def get_image_size_with_most_features(self) -> ImageSize:
# Result in the max possible feature size (h:w = 16:1)
return ImageSize(height=8000, width=50)
class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
@ -354,7 +356,8 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
@ -363,7 +366,7 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
num_images=num_images)
}
hf_processor = self._get_hf_processor()
hf_processor = self.info.get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return ProcessorInputs(
@ -372,10 +375,7 @@ class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
)
class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Phi3VProfilingInfo(self.ctx)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
def _call_hf_processor(
self,
@ -416,10 +416,10 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
@ -431,9 +431,10 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self._get_num_image_tokens(
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
return [_IMAGE_TOKEN_ID] * num_image_tokens + [bos_token_id]
@ -451,9 +452,9 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
def _apply_prompt_replacements(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls,
@ -466,7 +467,7 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = {
modality: [
_PlaceholderInfo(
PlaceholderInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
@ -499,7 +500,9 @@ class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
return result
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
info=Phi3VProcessingInfo,
dummy_inputs=Phi3VDummyInputsBuilder)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={

View File

@ -38,11 +38,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
@ -80,12 +80,12 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
return feat_lengths, output_lengths
class Qwen2AudioProcessingMixin(ProcessingMixin):
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2AudioConfig)
def _get_hf_processor(
def get_hf_processor(
self,
*,
# Ignored in initialization
@ -93,36 +93,37 @@ class Qwen2AudioProcessingMixin(ProcessingMixin):
) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor)
def _get_feature_extractor(
def get_feature_extractor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config = self._get_hf_config()
hf_config = self.get_hf_config()
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
return {"audio": max_output_lengths}
class Qwen2AudioDummyInputsBuilder(
BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
@ -139,14 +140,11 @@ class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
)
class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Qwen2AudioProfilingInfo(self.ctx)
class Qwen2AudioMultiModalProcessor(
BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor()
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor(
@ -161,7 +159,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
if audios:
mm_data["audios"] = audios
feature_extractor = self._get_feature_extractor(**mm_kwargs)
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
@ -194,7 +192,7 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
placeholder = hf_config.audio_token_index
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
@ -234,10 +232,13 @@ class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
# has already performed processing for multi-audio input when the input
# audios are short (the corresponding placeholders may take up fewer
# tokens than the number of audio items)
return not hasattr(self._get_hf_processor(), "audio_token")
return not hasattr(self.info.get_hf_processor(), "audio_token")
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(
Qwen2AudioMultiModalProcessor,
info=Qwen2AudioProcessingInfo,
dummy_inputs=Qwen2AudioDummyInputsBuilder)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

View File

@ -57,11 +57,10 @@ from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs,
NestedTensors, VideoItem)
from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataParser)
MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@ -709,12 +708,12 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
return super()._parse_video_data(data)
class Qwen2VLProcessingMixin(ProcessingMixin):
class Qwen2VLProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2VLConfig)
def _get_hf_processor(
def get_hf_processor(
self,
*,
min_pixels: Optional[int] = None,
@ -736,18 +735,27 @@ class Qwen2VLProcessingMixin(ProcessingMixin):
return hf_processor
def _get_image_processor(
def get_image_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
):
hf_processor = self._get_hf_processor(min_pixels=min_pixels,
max_pixels=max_pixels)
hf_processor = self.get_hf_processor(min_pixels=min_pixels,
max_pixels=max_pixels)
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
}
def _get_vision_info(
self,
*,
@ -755,15 +763,17 @@ class Qwen2VLProcessingMixin(ProcessingMixin):
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
image_processor: Optional[Qwen2VLImageProcessor],
) -> tuple[ImageSize, int]:
hf_config = self._get_hf_config()
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
image_processor = self._get_image_processor()
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
@ -787,70 +797,65 @@ class Qwen2VLProcessingMixin(ProcessingMixin):
return preprocessed_size, num_vision_tokens
def _get_num_image_tokens(
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
image_processor: Optional[Qwen2VLImageProcessor],
) -> int:
_, num_image_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
image_processor=image_processor,
)
return num_image_tokens
def _get_num_video_tokens(
def get_num_video_tokens(
self,
*,
image_width: int,
image_height: int,
num_frames: int,
image_processor: Optional[Qwen2VLImageProcessor],
) -> int:
_, num_video_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
num_frames=num_frames,
image_processor=image_processor,
)
return num_video_tokens
class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self._get_max_image_tokens(),
"video": self._get_max_video_tokens(seq_len),
}
def _get_image_size_with_most_features(self) -> ImageSize:
def get_image_size_with_most_features(self) -> ImageSize:
max_image_size, _ = self._get_vision_info(
image_width=9999999,
image_height=9999999,
image_processor=None,
)
return max_image_size
def _get_max_image_tokens(self) -> int:
target_width, target_height = self._get_image_size_with_most_features()
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_image_tokens(
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
image_processor=None,
)
def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
while True:
next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens(
next_max_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
image_processor=None,
)
if next_max_tokens > max_tokens:
@ -860,12 +865,12 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
@ -877,15 +882,19 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
return num_frames
def _get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
def get_max_video_tokens(self, seq_len: int) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_video_tokens(
return self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=self.get_num_frames_with_most_features(seq_len),
image_processor=None,
)
class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
@ -894,10 +903,14 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
hf_processor = self._get_hf_processor()
hf_processor = self.info.get_hf_processor()
image_token: str = hf_processor.image_token
video_token: str = hf_processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
mm_data = {
"image":
@ -908,7 +921,7 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=target_num_frames,
num_videos=num_videos,
)
}
@ -919,11 +932,8 @@ class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
)
class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Qwen2VLProfilingInfo(self.ctx)
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser()
@ -934,8 +944,9 @@ class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self._get_image_processor(**hf_processor_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered
@ -991,7 +1002,9 @@ class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,
info=Qwen2VLProcessingInfo,
dummy_inputs=Qwen2VLDummyInputsBuilder)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
packed_modules_mapping = {

View File

@ -24,11 +24,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
@ -59,9 +58,9 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
UltravoxAudioEmbeddingInputs]
class UltravoxProcessingMixin(ProcessingMixin):
class UltravoxProcessingInfo(BaseProcessingInfo):
def _get_hf_processor(
def get_hf_processor(
self,
*,
# Ignored in initialization
@ -76,37 +75,38 @@ class UltravoxProcessingMixin(ProcessingMixin):
hf_processor.audio_token_replacement = _AUDIO_PLACEHOLDER_OVERRIDE
return hf_processor
def _get_feature_extractor(
def get_feature_extractor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
audio_processor = hf_processor.audio_processor # type: ignore
feature_extractor = audio_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
feature_extractor = self._get_feature_extractor()
feature_extractor = self.get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)
return {"audio": max_audio_tokens}
class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
@ -123,14 +123,11 @@ class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo):
)
class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return UltravoxProfilingInfo(self.ctx)
class UltravoxMultiModalProcessor(
BaseMultiModalProcessor[UltravoxProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor()
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor(
@ -141,7 +138,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
) -> BatchFeature:
# Text-only input not supported in composite processor
if not mm_data:
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode(
prompt,
@ -160,7 +157,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
mm_kwargs=mm_kwargs,
)
feature_extractor = self._get_feature_extractor()
feature_extractor = self.info.get_feature_extractor()
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
@ -208,7 +205,7 @@ class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
placeholder = hf_processor.audio_token_replacement # type: ignore
def get_replacement_ultravox(item_idx: int):
@ -342,7 +339,10 @@ class ModifiedWhisperEncoder(WhisperEncoder):
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(UltravoxMultiModalProcessor,
info=UltravoxProcessingInfo,
dummy_inputs=UltravoxDummyInputsBuilder
)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(

View File

@ -4,12 +4,13 @@ from collections import defaultdict
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union)
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from vllm import envs
from vllm.inputs import DummyData, InputProcessingContext
import vllm.envs as envs
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
@ -20,7 +21,9 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
from .profiling import BaseProfilingInfo
if TYPE_CHECKING:
from .profiling import BaseDummyInputsBuilder
logger = init_logger(__name__)
@ -46,8 +49,8 @@ class PromptReplacement:
if it does not depend on the input.
"""
def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
return _BoundPromptReplacement(
def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
return BoundPromptReplacement(
tokenizer=tokenizer,
modality=self.modality,
_target=self.target,
@ -128,7 +131,7 @@ class _BoundPromptSequence:
@dataclass
class _BoundPromptReplacement:
class BoundPromptReplacement:
tokenizer: AnyTokenizer = field(repr=False)
modality: str
@ -207,7 +210,7 @@ def iter_token_matches(
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
prompt_repl: _BoundPromptReplacement
prompt_repl: BoundPromptReplacement
@property
def modality(self) -> str:
@ -255,7 +258,7 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
@dataclass
class _PlaceholderInfo:
class PlaceholderInfo:
modality: str
item_idx: int
start_idx: int
@ -274,7 +277,7 @@ class _PlaceholderInfo:
def find_token_matches(
prompt: list[int],
prompt_repls: Sequence[_BoundPromptReplacement],
prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [
@ -286,7 +289,7 @@ def find_token_matches(
def find_text_matches(
prompt: str,
prompt_repls: Sequence[_BoundPromptReplacement],
prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [
@ -390,9 +393,9 @@ def replace_text_matches(
def _iter_modality_placeholders(
prompt: list[int],
modality: str,
modality_repls: Sequence[_BoundPromptReplacement],
modality_repls: Sequence[BoundPromptReplacement],
modal_item_count: int,
) -> Iterable[_PlaceholderInfo]:
) -> Iterable[PlaceholderInfo]:
if modal_item_count == 0:
return
@ -413,7 +416,7 @@ def _iter_modality_placeholders(
continue
if prompt[start_idx:end_idx] == repl_tokens:
yield _PlaceholderInfo(
yield PlaceholderInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx,
@ -434,10 +437,10 @@ def _iter_modality_placeholders(
def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Iterable[_PlaceholderInfo]:
) -> Iterable[PlaceholderInfo]:
"""
For each modality, yield each set of placeholder tokens found in
:code:`prompt`.
@ -455,10 +458,10 @@ def _iter_placeholders(
def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]:
) -> Mapping[str, list[PlaceholderInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
return dict(full_groupby_modality(it))
@ -524,29 +527,59 @@ class ProcessingCache:
self._cache.put(cache_key, output_kwargs)
class ProcessingMixin:
"""
Contains helper functions to perform processing.
class BaseProcessingInfo:
"""Base class containing information to perform processing."""
Not to be confused with :class:`transformers.ProcessorMixin`.
"""
ctx: InputProcessingContext
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__()
def _get_tokenizer(self) -> AnyTokenizer:
self.ctx = ctx
@property
def model_id(self) -> str:
return self.ctx.model_config.model
def get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer
def _get_hf_config(self) -> PretrainedConfig:
def get_hf_config(self) -> PretrainedConfig:
return self.ctx.get_hf_config()
def _get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
"""
Subclasses can override this method to handle
specific kwargs from model config or user inputs.
"""
return self.ctx.get_hf_processor(**kwargs)
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
"""
Return the maximum supported number of items for each modality.
class BaseMultiModalProcessor(ProcessingMixin, ABC):
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
_I = TypeVar("_I", bound=BaseProcessingInfo)
class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
Abstract base class to process multi-modal inputs to be used in vLLM.
@ -554,18 +587,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
def __init__(self,
ctx: InputProcessingContext,
info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__()
self.ctx = ctx
self.info = info
self.dummy_inputs = dummy_inputs
self.cache = cache
self.enable_sanity_checks = enable_sanity_checks
self.data_parser = self._get_data_parser()
self.profiling_info = self._get_profiling_info()
def __call__(
self,
@ -585,13 +619,6 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
return MultiModalDataParser()
def _get_profiling_info(self) -> BaseProfilingInfo:
"""
Get the profiling information to find the worst-case memory usage of
the model.
"""
raise NotImplementedError
def _to_mm_items(
self,
mm_data: MultiModalDataDict,
@ -602,7 +629,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
mm_limits = self.ctx.get_mm_config().limit_per_prompt
mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
for modality, items in mm_items.items():
limit = mm_limits.get(modality, 1)
if len(items) > limit:
@ -646,19 +673,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _find_mm_placeholders(
self,
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[_PlaceholderInfo]]:
) -> Mapping[str, list[PlaceholderInfo]]:
return find_mm_placeholders(mm_prompt_repls, new_token_ids,
mm_item_counts)
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
) -> tuple[Mapping[str, object], Mapping[str, object]]:
processor_data = dict[str, object]()
passthrough_data = dict[str, object]()
for items in mm_items.values():
processor_data.update(items.get_processor_data())
@ -678,8 +705,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
Call the HF processor on the prompt text and
associated multi-modal data.
"""
return self.ctx.call_hf_processor(
self._get_hf_processor(**mm_kwargs),
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
mm_kwargs,
)
@ -738,8 +765,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text
dummy_inputs = self.profiling_info.get_dummy_processor_inputs(
self.ctx.model_config.max_model_len,
dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
self.info.ctx.model_config.max_model_len,
mm_missing_counts,
)
@ -762,7 +789,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
caching the results and reusing cached results.
"""
cache = self.cache
model_id = self.ctx.model_config.model
model_id = self.info.model_id
_, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data:
@ -838,8 +865,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _bind_and_group_repls(
self,
prompt_repls: list[PromptReplacement],
) -> dict[str, list[_BoundPromptReplacement]]:
tokenizer = self._get_tokenizer()
) -> dict[str, list[BoundPromptReplacement]]:
tokenizer = self.info.get_tokenizer()
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
return dict(full_groupby_modality(it))
@ -859,10 +886,10 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _apply_prompt_replacements(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]:
tokenizer = self._get_tokenizer()
) -> tuple[list[int], str, Mapping[str, list[PlaceholderInfo]]]:
tokenizer = self.info.get_tokenizer()
mm_token_matches = {
modality: find_token_matches(token_ids, prompt_repls)
@ -950,7 +977,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def _validate_mm_placeholders(
self,
mm_placeholders: Mapping[str, list[_PlaceholderInfo]],
mm_placeholders: Mapping[str, list[PlaceholderInfo]],
mm_item_counts: Mapping[str, int],
*,
allow_missing: bool = False,
@ -1001,7 +1028,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# instead of rehashing.
if envs.VLLM_USE_V1:
model_id = self.ctx.model_config.model
model_id = self.info.model_id
mm_hashes = {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
@ -1046,7 +1073,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
allow_missing=True,
)
mm_missing_repls = dict[str, list[_BoundPromptReplacement]]()
mm_missing_repls = dict[str, list[BoundPromptReplacement]]()
for modality, missing_repl_count in mm_missing_repl_counts.items():
if missing_repl_count == 0:
mm_missing_repls[modality] = []
@ -1059,7 +1086,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()):
tokenizer = self._get_tokenizer()
tokenizer = self.info.get_tokenizer()
prompt_text = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders
else:
@ -1090,79 +1117,3 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges,
)
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
profiling = self.profiling_info
processor_inputs = profiling.get_dummy_processor_inputs(
seq_len, mm_counts)
return self.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
profiling = self.profiling_info
mm_counts = profiling.get_mm_limits()
mm_max_tokens_per_item = profiling.get_mm_max_tokens_per_item(seq_len)
if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits`"
f"({set(mm_counts.keys())}) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
raise AssertionError(
f"The processed dummy data has a total of "
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
total_len = len(prompt_token_ids)
# V0 does not support chunked prefill.
if total_len > seq_len and not envs.VLLM_USE_V1:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None,
multi_modal_placeholders=None,
)
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=placeholders_by_modality,
)

View File

@ -1,16 +1,18 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Optional
from typing import Generic, TypeVar
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm.inputs import InputProcessingContext
import vllm.envs as envs
from vllm.inputs import DummyData
from vllm.logger import init_logger
from .inputs import MultiModalDataDict
from .inputs import MultiModalDataDict, MultiModalInputsV2
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
logger = init_logger(__name__)
@ -23,39 +25,19 @@ class ProcessorInputs:
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
class BaseProfilingInfo(ABC):
_I = TypeVar("_I", bound=BaseProcessingInfo)
class BaseDummyInputsBuilder(ABC, Generic[_I]):
"""
Abstract base class that provides the information necessary to profile
Abstract base class that constructs the dummy data to profile
multi-modal models.
"""
def __init__(self, ctx: InputProcessingContext) -> None:
def __init__(self, info: _I) -> None:
super().__init__()
self.ctx = ctx
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
self.info = info
@abstractmethod
def get_dummy_processor_inputs(
@ -64,8 +46,8 @@ class BaseProfilingInfo(ABC):
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
"""
Build the multi-modal portion of the input which, after processing,
results in `mm_max_tokens` in :meth:`get_mm_max_tokens_per_item`.
Build the input which, after processing, results in
`self.info.get_mm_max_tokens_per_item()` placeholder tokens.
"""
raise NotImplementedError
@ -99,11 +81,33 @@ class BaseProfilingInfo(ABC):
video = np.zeros((num_frames, width, height, 3))
return [video] * num_videos
def get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.ctx.get_mm_config()
class MultiModalProfiler(Generic[_I]):
"""
Contains code for running memory profiling for multi-modal models.
"""
def __init__(
self,
processor: BaseMultiModalProcessor[_I],
) -> None:
super().__init__()
self.processor = processor
@property
def processing_info(self) -> BaseProcessingInfo:
return self.processor.info
@property
def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
return self.processor.dummy_inputs
def _get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.processing_info.ctx.get_mm_config()
mm_limit_per_prompt = mm_config.limit_per_prompt
supported_mm_limits = self.get_supported_mm_limits()
supported_mm_limits = self.processing_info.get_supported_mm_limits()
mm_limits = {
modality: mm_limit_per_prompt.get(modality, 1)
@ -119,3 +123,81 @@ class BaseProfilingInfo(ABC):
f"at most {supported_limit} {modality} items.")
return mm_limits
def _get_dummy_mm_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalInputsV2:
factory = self.dummy_inputs
processor_inputs = factory.get_dummy_processor_inputs(
seq_len, mm_counts)
return self.processor.apply(
prompt_text=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
)
def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData
mm_counts = self._get_mm_limits()
info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)
if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError(
"The keys returned by `get_supported_mm_limits`"
f"({set(mm_counts.keys())}) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
expected_placeholders_by_modality = {
modality: mm_max_tokens_per_item[modality] * mm_counts[modality]
for modality in placeholders_by_modality
}
if total_placeholders_by_modality != expected_placeholders_by_modality:
raise AssertionError(
f"The processed dummy data has a total of "
f"{total_placeholders_by_modality} placeholder tokens, which "
f"is not the expected {expected_placeholders_by_modality} "
"tokens.")
total_len = len(prompt_token_ids)
# V0 does not support chunked prefill.
if total_len > seq_len and not envs.VLLM_USE_V1:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
multi_modal_data=None,
multi_modal_placeholders=None,
)
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=placeholders_by_modality,
)

View File

@ -1,7 +1,8 @@
import functools
from collections import UserDict
from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol,
Sequence, Type, TypeVar)
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional,
Protocol, Sequence, Type, TypeVar)
import torch.nn as nn
@ -14,7 +15,9 @@ from .audio import AudioPlugin
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
from .image import ImagePlugin
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import BaseMultiModalProcessor, ProcessingCache
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache)
from .profiling import BaseDummyInputsBuilder
from .utils import cached_get_tokenizer
from .video import VideoPlugin
@ -27,20 +30,59 @@ logger = init_logger(__name__)
MM_CACHE_SIZE = 256
N = TypeVar("N", bound=Type[nn.Module])
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
class MultiModalProcessorFactory(Protocol):
class ProcessingInfoFactory(Protocol[_I_co]):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def __call__(
self,
ctx: InputProcessingContext,
) -> _I_co:
...
class DummyInputsBuilderFactory(Protocol[_I]):
"""
Constructs a :class:`BaseDummyInputsBuilder` instance from the context.
"""
def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
...
class MultiModalProcessorFactory(Protocol[_I]):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def __call__(
self,
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
) -> BaseMultiModalProcessor:
) -> BaseMultiModalProcessor[_I]:
...
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
info: ProcessingInfoFactory[_I]
processor: MultiModalProcessorFactory[_I]
dummy_inputs: DummyInputsBuilderFactory[_I]
def build_processor(
self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
):
info = self.info(ctx)
dummy_inputs_builder = self.dummy_inputs(info)
return self.processor(info, dummy_inputs_builder, cache=cache)
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
"""
Wraps `_limits_by_model` for a more informative error message
@ -71,7 +113,7 @@ class MultiModalRegistry:
self._plugins = {p.get_data_key(): p for p in plugins}
self._processor_factories = ClassRegistry[nn.Module,
MultiModalProcessorFactory]()
_ProcessorFactories]()
# This is used for non-multimodal models
self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}
@ -224,7 +266,7 @@ class MultiModalRegistry:
tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len
return processor.profiling_info.get_mm_max_tokens_per_item(seq_len)
return processor.info.get_mm_max_tokens_per_item(seq_len)
return {
key: plugin.get_max_multimodal_tokens(model_config)
@ -315,7 +357,10 @@ class MultiModalRegistry:
def register_processor(
self,
factory: MultiModalProcessorFactory,
processor: MultiModalProcessorFactory[_I],
*,
info: ProcessingInfoFactory[_I],
dummy_inputs: DummyInputsBuilderFactory[_I],
):
"""
Register a multi-modal processor to a model class. The processor
@ -336,7 +381,11 @@ class MultiModalRegistry:
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._processor_factories[model_cls] = factory
self._processor_factories[model_cls] = _ProcessorFactories(
info=info,
dummy_inputs=dummy_inputs,
processor=processor,
)
return model_cls
@ -359,15 +408,15 @@ class MultiModalRegistry:
self,
model_config: "ModelConfig",
tokenizer: AnyTokenizer,
) -> BaseMultiModalProcessor:
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
Create a multi-modal processor for a specific model and tokenizer.
"""
model_cls = self._get_model_cls(model_config)
processor_factory = self._processor_factories[model_cls]
factories = self._processor_factories[model_cls]
ctx = InputProcessingContext(model_config, tokenizer)
cache = (None if model_config.disable_mm_preprocessor_cache else
self._processing_cache)
return processor_factory(ctx, cache=cache)
return factories.build_processor(ctx, cache=cache)