[VLM] Separate out profiling-related logic (#11746)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit 996357e480
parent 2a622d704a
Cyrus Leung committed 2025-01-06 16:02:21 +08:00 (via GitHub)
17 changed files with 1036 additions and 739 deletions
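
The same pattern is applied to every model file below: the profiling-only methods (get_supported_mm_limits, get_mm_max_tokens_per_item, and the dummy-input builders) move out of each model's BaseMultiModalProcessor subclass into a dedicated BaseProfilingInfo subclass, a small ProcessingMixin carries the HF config/processor helpers that both classes share, and the processor only declares which profiling info it uses via _get_profiling_info(). A minimal sketch of that shape, distilled from the Aria and BLIP-2 diffs below (the "MyModel" names, the num_image_tokens config field, and the 336 image size are placeholders, not part of this commit):

from collections.abc import Mapping
from typing import Optional

from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingMixin
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs


class MyModelProcessingMixin(ProcessingMixin):
    # Helpers shared by the processor and its profiling info.

    def _get_num_image_tokens(self) -> int:
        hf_config = self.ctx.get_hf_config()
        return hf_config.num_image_tokens  # placeholder field


class MyModelProfilingInfo(MyModelProcessingMixin, BaseProfilingInfo):
    # Profiling-only logic: supported limits, max token counts, dummy inputs.

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": self._get_num_image_tokens()}

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)
        mm_data = {
            "image":
            self._get_dummy_images(width=336,  # placeholder size
                                   height=336,
                                   num_images=num_images)
        }
        return ProcessorInputs(prompt_text="", mm_data=mm_data)


class MyModelMultiModalProcessor(MyModelProcessingMixin,
                                 BaseMultiModalProcessor):
    # The processor now only has to say which profiling info it uses;
    # field configs and prompt replacements stay here.

    def _get_profiling_info(self) -> BaseProfilingInfo:
        return MyModelProfilingInfo(self.ctx)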

View File

@@ -586,9 +586,10 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
     )

     processor = processor_factory(ctx, cache=None)
+    profiler = processor.profiling_info

     mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
-    processor.get_supported_mm_limits = mock_supported_mm_limits
+    profiler.get_supported_mm_limits = mock_supported_mm_limits

     if is_valid:
         exc_ctx = nullcontext()
@@ -596,7 +597,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
         exc_ctx = pytest.raises(ValueError, match="this model only supports")

     with exc_ctx:
-        processor._get_and_validate_dummy_mm_counts()
+        profiler.get_mm_limits()


 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@@ -723,7 +724,7 @@ def _test_processing_cache_correctness(
     }
     mm_counts = {k: len(vs) for k, vs in mm_data.items()}

-    prompt = baseline_processor._get_dummy_processor_inputs(
+    prompt = baseline_processor.profiling_info.get_dummy_processor_inputs(
        model_config.max_model_len,
        mm_counts,
    ).prompt_text
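
In other words, callers that used to reach the processor's private helpers now go through its profiling_info attribute, as the updated test shows. A rough usage sketch under that assumption (construction of the processor is abbreviated; the sequence length is an arbitrary example):

# Sketch only: assumes `ctx` is an InputProcessingContext for the model.
processor = processor_factory(ctx, cache=None)

profiler = processor.profiling_info            # a BaseProfilingInfo subclass
mm_limits = profiler.get_mm_limits()           # validated per-modality limits
dummy = profiler.get_dummy_processor_inputs(
    seq_len=4096,
    mm_counts={"image": 1},
)
print(dummy.prompt_text, list(dummy.mm_data))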

View File

@@ -24,8 +24,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
+                                        MultiModalDataItems, ProcessingMixin,
                                         PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
                                                   AriaVisionConfig)
@@ -444,54 +445,33 @@ def build_mm_projector(config: PretrainedConfig):
     )


-class AriaMultiModalProcessor(BaseMultiModalProcessor):
+class AriaProcessingMixin(ProcessingMixin):
+
+    def _get_hf_config(self):
+        return self.ctx.get_hf_config()
+
+    def _get_vision_config(self) -> AriaVisionConfig:
+        return self._get_hf_config().vision_config
+
+    def _get_num_image_tokens(self) -> int:
+        hf_config = self._get_hf_config()
+        return max(hf_config.projector_patch_to_query_dict.values())
+
+
+class AriaProfilingInfo(AriaProcessingMixin, BaseProfilingInfo):

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}

-    def _get_num_image_tokens(self) -> int:
-        hf_config = self.ctx.get_hf_config()
-        return max(hf_config.projector_patch_to_query_dict.values())
-
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
         return {"image": self._get_num_image_tokens()}

-    def _get_mm_fields_config(
-        self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            pixel_mask=MultiModalFieldConfig.batched("image"),
-        )
-
-    def _get_prompt_replacements(
-        self,
-        mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
-    ) -> list[PromptReplacement]:
-        hf_config = self.ctx.get_hf_config()
-        image_token_id = hf_config.image_token_index
-        num_image_tokens = self._get_num_image_tokens()
-
-        return [
-            PromptReplacement(
-                modality="image",
-                target=[image_token_id],
-                replacement=[image_token_id] * num_image_tokens,
-            )
-        ]
-
-    def _get_dummy_processor_inputs(
+    def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        hf_config = self.ctx.get_hf_config()
-        vision_config: AriaVisionConfig = hf_config.vision_config
+        vision_config = self._get_vision_config()

         max_image_size = vision_config.image_size
         num_images = mm_counts.get("image", 0)
@@ -512,6 +492,41 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
     )


+class AriaMultiModalProcessor(AriaProcessingMixin, BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return AriaProfilingInfo(self.ctx)
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            pixel_mask=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        hf_config = self._get_hf_config()
+        image_token_id = hf_config.image_token_index
+        num_image_tokens = self._get_num_image_tokens()
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id],
+                replacement=[image_token_id] * num_image_tokens,
+            )
+        ]
+
+
 @MULTIMODAL_REGISTRY.register_processor(AriaMultiModalProcessor)
 class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
     """

View File

@@ -4,8 +4,8 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
 import torch
 import torch.nn as nn
-from transformers import (BatchFeature, Blip2Config, Blip2Processor,
-                          Blip2QFormerConfig, apply_chunking_to_forward)
+from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
+                          apply_chunking_to_forward)

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
@@ -18,8 +18,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
+                                        MultiModalDataItems, ProcessingMixin,
                                         PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors

 from .blip import BlipVisionModel
@@ -396,20 +397,52 @@ class Blip2QFormerModel(nn.Module):
         return sequence_output


-class Blip2MultiModalProcessor(BaseMultiModalProcessor):
+class Blip2ProcessingMixin(ProcessingMixin):
+
+    def _get_hf_config(self):
+        return self.ctx.get_hf_config(Blip2Config)
+
+    def _get_num_image_tokens(self) -> int:
+        hf_config = self._get_hf_config()
+        return hf_config.num_query_tokens
+
+
+class Blip2ProfilingInfo(Blip2ProcessingMixin, BaseProfilingInfo):

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": 1}

-    def _get_num_image_tokens(self) -> int:
-        hf_config = self.ctx.get_hf_config(Blip2Config)
-        return hf_config.num_query_tokens
-
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
         return {"image": self._get_num_image_tokens()}

-    def _get_hf_processor(self) -> Blip2Processor:
-        return self.ctx.get_hf_processor(Blip2Processor)
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        hf_config = self._get_hf_config()
+        vision_config = hf_config.vision_config
+
+        max_image_size = vision_config.image_size
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
+
+
+class Blip2MultiModalProcessor(Blip2ProcessingMixin, BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return Blip2ProfilingInfo(self.ctx)

     def _get_mm_fields_config(
         self,
@@ -427,13 +460,13 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        max_image_tokens = self._get_num_image_tokens()
+        num_image_tokens = self._get_num_image_tokens()

         return [
             PromptReplacement(
                 modality="image",
                 target="</s>",
-                replacement="<image>" * max_image_tokens + "</s>",
+                replacement="<image>" * num_image_tokens + "</s>",
             )
         ]
@@ -457,29 +490,6 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):

         return result

-    def _get_dummy_processor_inputs(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
-        hf_config = self.ctx.get_hf_config(Blip2Config)
-        vision_config = hf_config.vision_config
-
-        max_image_size = vision_config.image_size
-        num_images = mm_counts.get("image", 0)
-
-        mm_data = {
-            "image":
-            self._get_dummy_images(width=max_image_size,
-                                   height=max_image_size,
-                                   num_images=num_images)
-        }
-
-        return ProcessorInputs(
-            prompt_text="",
-            mm_data=mm_data,
-        )
-

 @MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor)
 class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

View File

@@ -31,8 +31,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
+                                        MultiModalDataItems, ProcessingMixin,
                                         PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import print_warning_once
@@ -48,54 +49,33 @@ class ChameleonImagePixelInputs(TypedDict):
     """Shape: `(batch_size * num_images, num_channels, height, width)`"""


-class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
+class ChameleonProcessingMixin(ProcessingMixin):

-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": 1}
+    def _get_hf_config(self):
+        return self.ctx.get_hf_config(ChameleonConfig)
+
+    def _get_hf_processor(self):
+        return self.ctx.get_hf_processor(ChameleonProcessor)

     def _get_num_image_tokens(self) -> int:
         processor = self._get_hf_processor()
         return processor.image_seq_length


+class ChameleonProfilingInfo(ChameleonProcessingMixin, BaseProfilingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
         return {"image": self._get_num_image_tokens()}

-    def _get_hf_processor(self) -> ChameleonProcessor:
-        return self.ctx.get_hf_processor(ChameleonProcessor)
-
-    def _get_mm_fields_config(
-        self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
-
-    def _get_prompt_replacements(
-        self,
-        mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        out_mm_kwargs: MultiModalKwargs,
-    ) -> list[PromptReplacement]:
-        processor = self._get_hf_processor()
-
-        return [
-            PromptReplacement(
-                modality="image",
-                target="<image>",
-                replacement="".join([
-                    processor.image_start_token,
-                    processor.image_token * self._get_num_image_tokens(),
-                    processor.image_end_token,
-                ]),
-            )
-        ]
-
-    def _get_dummy_processor_inputs(
+    def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        config = self.ctx.get_hf_config(ChameleonConfig)
+        config = self._get_hf_config()

         width = height = config.vq_config.resolution
         num_images = mm_counts.get("image", 0)
@@ -112,6 +92,40 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
             mm_data=mm_data,
         )


+class ChameleonMultiModalProcessor(ChameleonProcessingMixin,
+                                   BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return ChameleonProfilingInfo(self.ctx)
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        processor = self._get_hf_processor(**hf_processor_mm_kwargs)
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target="<image>",
+                replacement="".join([
+                    processor.image_start_token,
+                    processor.image_token * self._get_num_image_tokens(),
+                    processor.image_end_token,
+                ]),
+            )
+        ]
+
     def apply(
         self,
         prompt_text: str,

View File

@@ -35,8 +35,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     NestedTensors, PlaceholderRange)
 from vllm.multimodal.parse import ImageProcessorItems, ImageSize
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
+                                        MultiModalDataItems, ProcessingMixin,
                                         PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsMultiModal, SupportsPP
@@ -63,18 +64,16 @@ class FuyuImagePatchInputs(TypedDict):
     """


-class FuyuMultiModalProcessor(BaseMultiModalProcessor):
+class FuyuProcessingMixin(ProcessingMixin):

-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": 1}
+    def _get_hf_config(self):
+        return self.ctx.get_hf_config(FuyuConfig)

-    def _get_image_target_size(self) -> ImageSize:
-        processor = self._get_hf_processor()
-        image_processor: FuyuImageProcessor = processor.image_processor
+    def _get_hf_processor(self):
+        return self.ctx.get_hf_processor(FuyuProcessor)

-        target_size = image_processor.size
-        return ImageSize(width=target_size["width"],
-                         height=target_size["height"])
+    def _get_image_processor(self) -> FuyuImageProcessor:
+        return self._get_hf_processor().image_processor

     def _get_image_feature_grid_size(
         self,
@@ -82,7 +81,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
         image_width: int,
         image_height: int,
     ) -> tuple[int, int]:
-        target_width, target_height = self._get_image_target_size()
+        image_processor = self._get_image_processor()
+        target_width = image_processor.size["width"]
+        target_height = image_processor.size["height"]

         if not (image_width <= target_width and image_height <= target_height):
             height_scale_factor = target_height / image_height
@@ -96,8 +97,14 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
         nrows = math.ceil(image_height / 30)
         return ncols, nrows


+class FuyuProfilingInfo(FuyuProcessingMixin, BaseProfilingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        target_width, target_height = self._get_image_target_size()
+        target_width, target_height = self._get_image_size_with_most_features()

         max_ncols, max_nrows = self._get_image_feature_grid_size(
             image_width=target_width,
@@ -107,8 +114,36 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):

         return {"image": max_image_tokens}

-    def _get_hf_processor(self) -> FuyuProcessor:
-        return self.ctx.get_hf_processor(FuyuProcessor)
+    def _get_image_size_with_most_features(self) -> ImageSize:
+        image_processor = self._get_image_processor()
+        return ImageSize(width=image_processor.size["width"],
+                         height=image_processor.size["height"])
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        target_width, target_height = self._get_image_size_with_most_features()
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
+
+
+class FuyuMultiModalProcessor(FuyuProcessingMixin, BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return FuyuProfilingInfo(self.ctx)

     def _call_hf_processor(
         self,
@@ -161,7 +196,7 @@
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self.ctx.get_hf_config(FuyuConfig)
+        hf_config = self._get_hf_config()
         bos_token_id = hf_config.bos_token_id

         tokenizer = self._get_tokenizer()
@@ -208,26 +243,6 @@

         return result

-    def _get_dummy_processor_inputs(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
-        target_width, target_height = self._get_image_target_size()
-
-        num_images = mm_counts.get("image", 0)
-        mm_data = {
-            "image":
-            self._get_dummy_images(width=target_width,
-                                   height=target_height,
-                                   num_images=num_images)
-        }
-
-        return ProcessorInputs(
-            prompt_text="",
-            mm_data=mm_data,
-        )
-

 @MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

View File

@@ -1,4 +1,4 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from functools import cached_property
 from typing import (Final, Iterable, List, Literal, Mapping, Optional,
                     Protocol, Set, Tuple, TypedDict, Union)
@@ -13,6 +13,7 @@ from transformers.models.pixtral import PixtralProcessor

 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
+from vllm.inputs import InputProcessingContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -25,9 +26,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize)
-from vllm.multimodal.processing import (InputProcessingContext,
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessingCache,
-                                        ProcessorInputs, PromptReplacement)
+                                        ProcessingMixin, PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors

 from .clip import CLIPVisionModel
@@ -37,7 +39,7 @@ from .pixtral import (PixtralHFVisionModel,
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import BaseVisionLanguageMultiModalProcessor
+from .vision import get_vision_encoder_info


 class LlavaImagePixelInputs(TypedDict):
@@ -94,30 +96,42 @@ class LlavaMultiModalProjector(nn.Module):

 class LlavaLikeConfig(Protocol):
     vision_config: Final[PretrainedConfig]
+    image_token_index: Final[int]
     vision_feature_select_strategy: Final[str]
-    vision_feature_layer: Final[Union[int, List[int]]]
+    vision_feature_layer: Final[Union[int, list[int]]]


-class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
+class LlavaLikeProcessor(Protocol):
+    image_token: Final[str]
+
+
+class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
+
+    def _get_hf_config(self) -> LlavaLikeConfig:
+        return self.ctx.get_hf_config(LlavaConfig)
+
+    def _get_vision_encoder_info(self):
+        return get_vision_encoder_info(self._get_hf_config())

     @abstractmethod
-    def _get_hf_config(self) -> LlavaLikeConfig:
+    def _get_hf_processor(self) -> LlavaLikeProcessor:
         raise NotImplementedError

-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None}
-
-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {"image": self._get_max_image_tokens()}
-
-    def _get_mm_fields_config(
+    def _get_num_image_tokens(
         self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            image_embeds=MultiModalFieldConfig.batched("image"),
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        hf_config = self._get_hf_config()
+        vision_encoder_info = self._get_vision_encoder_info()
+
+        return self._apply_feature_select_strategy(
+            hf_config.vision_feature_select_strategy,
+            vision_encoder_info.get_num_image_tokens(
+                image_width=image_width,
+                image_height=image_height,
+            ),
         )

     def _apply_feature_select_strategy(
@@ -133,31 +147,38 @@ class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
         msg = f"Unexpected feature select strategy: {strategy!r}"
         raise NotImplementedError(msg)

-    def _get_max_image_tokens(self) -> int:
-        hf_config = self._get_hf_config()

-        return self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
-            self._vision_encoder_info.get_max_image_tokens(),
+class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {"image": self._get_max_image_tokens()}
+
+    def _get_image_size_with_most_features(self) -> ImageSize:
+        vision_encoder_info = self._get_vision_encoder_info()
+        width = height = vision_encoder_info.get_image_size()
+        return ImageSize(width=width, height=height)
+
+    def _get_max_image_tokens(self) -> int:
+        target_width, target_height = self._get_image_size_with_most_features()
+
+        return self._get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
         )

-    def _get_dummy_image_size(self) -> ImageSize:
-        image_size = self._vision_encoder_info.get_image_size()
-        return ImageSize(image_size, image_size)
-
-    @abstractmethod
-    def _get_image_token(self) -> str:
-        raise NotImplementedError
-
-    def _get_dummy_processor_inputs(
+    def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
         num_images = mm_counts.get("image", 0)

-        image_token = self._get_image_token()
-        target_width, target_height = self._get_dummy_image_size()
+        processor = self._get_hf_processor()
+        image_token = processor.image_token
+        target_width, target_height = self._get_image_size_with_most_features()

         mm_data = {
             "image":
@@ -172,32 +193,32 @@ class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
         )


-class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
+class LlavaProcessingMixin(BaseLlavaProcessingMixin):

-    def _get_hf_config(self) -> LlavaConfig:
-        return self.ctx.get_hf_config(LlavaConfig)
-
-    def _get_hf_processor(self) -> LlavaProcessor:
+    def _get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaProcessor)

-    def _get_image_token(self) -> str:
-        return self._get_hf_processor().image_token

-    def _get_num_image_tokens(
+class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo):
+    pass
+
+
+class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
+                                   BaseMultiModalProcessor):
+
+    # Copied from BaseMultiModalProcessor
+    @abstractmethod
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        raise NotImplementedError
+
+    # Copied from BaseMultiModalProcessor
+    @abstractmethod
+    def _get_mm_fields_config(
         self,
-        *,
-        image_width: int,
-        image_height: int,
-    ) -> int:
-        hf_config = self._get_hf_config()
-
-        return self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
-            self._vision_encoder_info.get_num_image_tokens(
-                image_width=image_width,
-                image_height=image_height,
-            ),
-        )
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        raise NotImplementedError

     def _get_prompt_replacements(
         self,
@@ -232,16 +253,37 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
         ]


-class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
+class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):

-    def _get_hf_config(self) -> LlavaConfig:
-        return self.ctx.get_hf_config(LlavaConfig)
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return LlavaProfilingInfo(self.ctx)

-    def _get_hf_processor(self) -> PixtralProcessor:
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
+
+class PixtralHFProcessingMixin(BaseLlavaProcessingMixin):
+
+    def _get_hf_processor(self):
         return self.ctx.get_hf_processor(PixtralProcessor)

-    def _get_image_token(self) -> str:
-        return self._get_hf_processor().image_token
+
+class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo):
+    pass
+
+
+class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
+                                   BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return PixtralHFProfilingInfo(self.ctx)

     def _call_hf_processor(
         self,
@@ -270,6 +312,16 @@ class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):

         return processed_outputs

+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
@@ -316,7 +368,7 @@ def _build_llava_or_pixtral_hf_processor(
     *,
     cache: Optional[ProcessingCache] = None,
     enable_sanity_checks: bool = True,
-) -> BaseLlavaMultiModalProcessor:
+) -> BaseMultiModalProcessor:
     hf_config = ctx.get_hf_config(LlavaConfig)

     if isinstance(hf_config.vision_config, PixtralVisionConfig):
@@ -663,16 +715,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

 class MantisMultiModalProcessor(LlavaMultiModalProcessor):

-    def _get_hf_processor(self):
-        return self.ctx.get_hf_processor(LlavaProcessor)
-
     def apply(
         self,
         prompt_text: str,
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        hf_config = self.ctx.get_hf_config(LlavaConfig)
+        hf_config = self._get_hf_config()
         image_token_id = hf_config.image_token_index

         # Assume that it doesn't depend on the image size
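
A side effect of the refactor in this file is that the LLaVA family now describes its HF configs and processors with typing Protocols (LlavaLikeConfig and LlavaLikeProcessor here, LlavaNextLikeConfig in the llava_next file further down), so each mixin can narrow the config type it returns without importing the concrete subclass. A standalone sketch of that narrowing idea, with illustrative names rather than the real vLLM classes:

from typing import Final, Protocol


class LikeConfig(Protocol):
    vision_feature_select_strategy: Final[str]


class NextLikeConfig(LikeConfig, Protocol):
    image_grid_pinpoints: Final[list[list[int]]]


class BaseMixin:

    def _get_hf_config(self) -> LikeConfig:
        raise NotImplementedError


class NextMixin(BaseMixin):

    # Narrowing the declared return type lets code written against
    # NextMixin rely on image_grid_pinpoints being present.
    def _get_hf_config(self) -> NextLikeConfig:
        raise NotImplementedError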

View File

@@ -1,6 +1,6 @@
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Final, Iterable, List, Literal, Mapping, Optional,
+                    Protocol, Set, Tuple, TypedDict, Union)

 import numpy as np
 import torch
@@ -17,12 +17,14 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
 from vllm.multimodal.parse import ImageSize
+from vllm.multimodal.profiling import BaseProfilingInfo
 from vllm.sequence import IntermediateTensors

 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import (LlavaMultiModalProcessor, LlavaMultiModalProjector,
-                    init_vision_tower_for_llava)
+from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin,
+                    BaseLlavaProfilingInfo, LlavaLikeConfig,
+                    LlavaMultiModalProjector, init_vision_tower_for_llava)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
@@ -60,36 +62,18 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                              LlavaNextImageEmbeddingInputs]


-class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
+class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
+    image_grid_pinpoints: Final[list[list[int]]]

-    def _get_hf_config(self) -> LlavaNextConfig:
+
+class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
+
+    def _get_hf_config(self) -> LlavaNextLikeConfig:
         return self.ctx.get_hf_config(LlavaNextConfig)

-    def _get_hf_processor(self) -> LlavaNextProcessor:
+    def _get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaNextProcessor)

-    def _get_mm_fields_config(
-        self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            image_sizes=MultiModalFieldConfig.batched("image"),
-            image_embeds=MultiModalFieldConfig.batched("image"),
-        )
-
-    def _get_image_token(self) -> str:
-        return self._get_hf_processor().image_token
-
-    def _get_max_image_tokens(self) -> int:
-        largest_feature_size, _ = self._get_pinpoint_with_most_features()
-        return largest_feature_size
-
-    def _get_dummy_image_size(self) -> ImageSize:
-        _, pinpoint = self._get_pinpoint_with_most_features()
-        return pinpoint
-
     # Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
     def _get_num_image_tokens(
         self,
@@ -98,7 +82,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
         image_height: int,
     ) -> int:
         hf_config = self._get_hf_config()
-        vision_encoder_info = self._vision_encoder_info
+        vision_encoder_info = self._get_vision_encoder_info()

         base_feature_size = self._apply_feature_select_strategy(
             hf_config.vision_feature_select_strategy,
@@ -140,7 +124,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
         current_height = npatches * num_patch_height
         current_width = npatches * num_patch_width

-        # NOTE: HF resizes based on float32
+        # NOTE: Use float32 to remain consistent with HF output
         original_aspect_ratio = np.array(original_width / original_height,
                                          dtype=np.float32)
         current_aspect_ratio = np.array(current_width / current_height,
@@ -164,11 +148,10 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):

         return (unpadded_features, newline_features)

-    def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:
-        """
-        Get the grid pinpoint with the most features and
-        the corresponding feature size.
-        """
+
+class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
+
+    def _get_image_size_with_most_features(self) -> ImageSize:
         hf_config = self._get_hf_config()

         largest_feature_size, largest_feature_pinpoint = 0, None
@@ -183,7 +166,25 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
         if largest_feature_size == 0 or largest_feature_pinpoint is None:
             raise ValueError("Cannot have a largest feature size of 0!")

-        return largest_feature_size, largest_feature_pinpoint
+        return largest_feature_pinpoint
+
+
+class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
+                                   BaseLlavaMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return LlavaNextProfilingInfo(self.ctx)
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_sizes=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )


 @MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)

View File

@@ -15,11 +15,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
-                                   VideoEmbeddingItems, VideoProcessorItems)
-from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import (ImageSize, VideoEmbeddingItems,
+                                   VideoProcessorItems)
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        MultiModalDataItems, ProcessingMixin,
                                         PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
@@ -28,7 +31,7 @@ from .llava import init_vision_tower_for_llava
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import BaseVisionLanguageMultiModalProcessor
+from .vision import get_vision_encoder_info


 class LlavaNextVideoPixelInputs(TypedDict):
@@ -44,30 +47,17 @@ class LlavaNextVideoPixelInputs(TypedDict):
     """


-class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
+class LlavaNextVideoProcessingMixin(ProcessingMixin):

-    def _get_hf_config(self) -> LlavaNextVideoConfig:
+    def _get_hf_config(self):
         return self.ctx.get_hf_config(LlavaNextVideoConfig)

-    def _get_hf_processor(self) -> LlavaNextVideoProcessor:
+    def _get_vision_encoder_info(self):
+        return get_vision_encoder_info(self._get_hf_config())
+
+    def _get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaNextVideoProcessor)

-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"video": 1}
-
-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        num_frames = self._get_dummy_num_frames(seq_len)
-        max_video_tokens = self._get_max_video_tokens(num_frames)
-
-        return {"video": max_video_tokens}
-
-    def _get_mm_fields_config(
-        self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))
-
     def _get_num_frame_tokens(
         self,
         *,
@@ -77,7 +67,8 @@ class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
         hf_config = self._get_hf_config()
         spatial_pool_stride = hf_config.spatial_pool_stride

-        patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
+        vision_encoder_info = self._get_vision_encoder_info()
+        patch_grid_length = vision_encoder_info.get_patch_grid_length()
         pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

         return pooled_grid_length * pooled_grid_length
@@ -96,18 +87,43 @@ class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):

         return num_frame_tokens * num_frames

-    def _get_max_video_tokens(self, num_frames: int) -> int:
-        return self._get_num_video_tokens(image_width=999999,
-                                          image_height=999999,
-                                          num_frames=num_frames)
+
+class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
+                                  BaseProfilingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"video": 1}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        target_width, target_height = self._get_image_size_with_most_features()
+
+        max_video_tokens = self._get_num_video_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=self._get_dummy_num_frames(seq_len),
+        )
+
+        return {"video": max_video_tokens}
+
+    def _get_image_size_with_most_features(self) -> ImageSize:
+        vision_encoder_info = self._get_vision_encoder_info()
+        width = height = vision_encoder_info.get_image_size()
+        return ImageSize(width=width, height=height)

     def _get_max_video_frames(self, max_tokens: int) -> int:
+        target_width, target_height = self._get_image_size_with_most_features()
+
         num_frames = 0

         while True:
             next_num_frames = num_frames + 1
+            next_max_tokens = self._get_num_video_tokens(
+                image_width=target_width,
+                image_height=target_height,
+                num_frames=next_num_frames,
+            )

-            if self._get_max_video_tokens(next_num_frames) > max_tokens:
+            if next_max_tokens > max_tokens:
                 break

             num_frames = next_num_frames
@@ -122,12 +138,45 @@ class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):

         return max(max_total_frames // max(max_videos, 1), 1)

-    def _get_dummy_image_size(self) -> ImageSize:
-        image_size = self._vision_encoder_info.get_image_size()
-        return ImageSize(image_size, image_size)
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_videos = mm_counts.get("video", 0)

-    def _get_video_token(self) -> str:
-        return self._get_hf_processor().video_token
+        processor = self._get_hf_processor()
+        video_token = processor.video_token
+        target_width, target_height = self._get_image_size_with_most_features()
+
+        mm_data = {
+            "video":
+            self._get_dummy_videos(
+                width=target_width,
+                height=target_height,
+                num_frames=self._get_dummy_num_frames(seq_len),
+                num_videos=num_videos,
+            )
+        }
+
+        return ProcessorInputs(
+            prompt_text=video_token * num_videos,
+            mm_data=mm_data,
+        )
+
+
+class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
+                                        BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return LlavaNextVideoProfilingInfo(self.ctx)
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))

     def _get_prompt_replacements(
         self,
@@ -162,36 +211,11 @@ class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
             ),
         ]

-    def _get_dummy_processor_inputs(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
-        num_videos = mm_counts.get("video", 0)
-
-        video_token = self._get_video_token()
-        target_width, target_height = self._get_dummy_image_size()
-
-        mm_data = {
-            "video":
-            self._get_dummy_videos(
-                width=target_width,
-                height=target_height,
-                num_frames=self._get_dummy_num_frames(seq_len),
-                num_videos=num_videos,
-            )
-        }
-
-        return ProcessorInputs(
-            prompt_text=video_token * num_videos,
-            mm_data=mm_data,
-        )
-

 # adopted from transformers modeling_llava_next_video.py
 class LlavaNextVideoPooler(nn.Module):

-    def __init__(self, config):
+    def __init__(self, config: LlavaNextVideoConfig):
         super().__init__()

         mode = config.spatial_pool_mode
@@ -209,7 +233,7 @@ class LlavaNextVideoPooler(nn.Module):
             raise ValueError(
                 f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")

-    def forward(self, image_features):
+    def forward(self, image_features: torch.Tensor):
         ori_width = int(
             math.sqrt(image_features.shape[1] * self.image_size //
                       self.image_size))

View File

@@ -1,7 +1,7 @@
 import math
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Final, Iterable, List, Literal, Mapping, Optional,
+                    Protocol, Set, Tuple, TypedDict, Union)

 import numpy as np
 import torch
@@ -21,15 +21,16 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
 from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
                                    VideoProcessorItems)
-from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
-                                        PromptReplacement)
+from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import init_vision_tower_for_llava
-from .llava_next import LlavaNextMultiModalProcessor
+from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava
+from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor,
+                         LlavaNextProcessingMixin)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -82,40 +83,18 @@ LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
                                   LlavaOnevisionVideoPixelInputs]


-class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
+class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
+    video_token_index: Final[int]

-    def _get_hf_config(self) -> LlavaOnevisionConfig:
+
+class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
+
+    def _get_hf_config(self) -> LlavaOnevisionLikeConfig:
         return self.ctx.get_hf_config(LlavaOnevisionConfig)

-    def _get_hf_processor(self) -> LlavaOnevisionProcessor:
+    def _get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaOnevisionProcessor)

-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None, "video": None}
-
-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        max_image_tokens = self._get_max_image_tokens()
-
-        num_frames = self._get_dummy_num_frames(seq_len)
-        max_video_tokens = self._get_max_video_tokens(num_frames)
-
-        return {
-            "image": max_image_tokens,
-            "video": max_video_tokens,
-        }
-
-    def _get_mm_fields_config(
-        self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            image_sizes=MultiModalFieldConfig.batched("image"),
-            image_embeds=MultiModalFieldConfig.batched("image"),
-            pixel_values_videos=MultiModalFieldConfig.batched("video"),
-        )
-
     def _get_num_unpadded_features(
         self,
         *,
@@ -128,7 +107,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
         current_height = npatches * num_patch_height
         current_width = npatches * num_patch_width

-        # NOTE: HF resizes based on float32
+        # NOTE: Use float32 to remain consistent with HF output
         original_aspect_ratio = np.array(original_width / original_height,
                                          dtype=np.float32)
         current_aspect_ratio = np.array(current_width / current_height,
@@ -167,7 +146,8 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
         hf_config = self._get_hf_config()
         spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)

-        patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
+        vision_encoder_info = self._get_vision_encoder_info()
+        patch_grid_length = vision_encoder_info.get_patch_grid_length()
         pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

         return pooled_grid_length * pooled_grid_length
@@ -186,18 +166,33 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):

         return num_frame_tokens * num_frames + 1  # Newline token

-    def _get_max_video_tokens(self, num_frames: int) -> int:
-        return self._get_num_video_tokens(image_width=999999,
-                                          image_height=999999,
-                                          num_frames=num_frames)
+
+class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
+                                  BaseLlavaProfilingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None, "video": None}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {
+            "image": self._get_max_image_tokens(),
+            "video": self._get_max_video_tokens(seq_len),
+        }

     def _get_max_video_frames(self, max_tokens: int) -> int:
+        target_width, target_height = self._get_image_size_with_most_features()
+
         num_frames = 0

         while True:
             next_num_frames = num_frames + 1
+            next_max_tokens = self._get_num_video_tokens(
+                image_width=target_width,
+                image_height=target_height,
+                num_frames=next_num_frames,
+            )

-            if self._get_max_video_tokens(next_num_frames) > max_tokens:
+            if next_max_tokens > max_tokens:
                 break

             num_frames = next_num_frames
@@ -215,8 +210,65 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):

         return max(max_total_frames // max(max_videos, 1), 1)

-    def _get_video_token(self) -> str:
-        return self._get_hf_processor().video_token
+    def _get_max_video_tokens(self, seq_len: int) -> int:
+        target_width, target_height = self._get_image_size_with_most_features()
+
+        return self._get_num_video_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=self._get_dummy_num_frames(seq_len),
+        )
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_images = mm_counts.get("image", 0)
+        num_videos = mm_counts.get("video", 0)
+
+        processor = self._get_hf_processor()
+        image_token = processor.image_token
+        video_token = processor.video_token
+        target_width, target_height = self._get_image_size_with_most_features()
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images),
+            "video":
+            self._get_dummy_videos(
+                width=target_width,
+                height=target_height,
+                num_frames=self._get_dummy_num_frames(seq_len),
+                num_videos=num_videos,
+            )
+        }
+
+        return ProcessorInputs(
+            prompt_text=image_token * num_images + video_token * num_videos,
+            mm_data=mm_data,
+        )
+
+
+class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
+                                        LlavaNextMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return LlavaOnevisionProfilingInfo(self.ctx)
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_sizes=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+            pixel_values_videos=MultiModalFieldConfig.batched("video"),
+        )

     def _call_hf_processor(
         self,
@@ -235,7 +287,8 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
             mm_kwargs=mm_kwargs,
         )

-        video_token = self._get_video_token()
+        processor = self._get_hf_processor()
+        video_token = processor.video_token

         # LLaVA-OneVision processor doesn't support multiple videos
         # with different sizes when converting back to tensors
@@ -303,37 +356,6 @@ class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
             ),
         ]

-    def _get_dummy_processor_inputs(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int],
-    ) -> ProcessorInputs:
-        num_images = mm_counts.get("image", 0)
-        num_videos = mm_counts.get("video", 0)
-
-        image_token = self._get_image_token()
-        video_token = self._get_video_token()
-        target_width, target_height = self._get_dummy_image_size()
-
-        mm_data = {
-            "image":
-            self._get_dummy_images(width=target_width,
-                                   height=target_height,
-                                   num_images=num_images),
-            "video":
-            self._get_dummy_videos(
-                width=target_width,
-                height=target_height,
-                num_frames=self._get_dummy_num_frames(seq_len),
-                num_videos=num_videos,
-            )
-        }
-
-        return ProcessorInputs(
-            prompt_text=image_token * num_images + video_token * num_videos,
-            mm_data=mm_data,
-        )
-

 class LlavaOnevisionMultiModalProjector(nn.Module):

View File

@@ -14,7 +14,7 @@
 # limitations under the License.
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
+from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union

 import torch
 import torch.nn as nn
@@ -28,22 +28,23 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
-from vllm.multimodal.parse import ImageEmbeddingItems, ImageProcessorItems
+from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
+                                   ImageSize)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        MultiModalDataItems, ProcessorInputs,
+                                        MultiModalDataItems, ProcessingMixin,
                                         PromptReplacement,
                                         _BoundPromptReplacement,
                                         _PlaceholderInfo)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

-from .clip import dummy_image_for_clip
+from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
@@ -54,10 +55,6 @@ logger = init_logger(__name__)
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 32044

-# Result in the max possible feature size (h:w = 16:1)
-MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
-MAX_IMAGE_FEATURE_SIZE_WIDTH = 50
-
 CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                      hidden_act="quick_gelu",
                                                      hidden_size=1024,
@@ -305,10 +302,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline


-class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
+class Phi3VProcessingMixin(ProcessingMixin):

-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None}
+    def _get_hf_processor(
+        self,
+        *,
+        num_crops: Optional[int] = None,
+    ) -> ProcessorMixin:
+        if num_crops is not None:
+            return self.ctx.get_hf_processor(num_crops=num_crops)
+        return self.ctx.get_hf_processor()

     def _get_num_image_tokens(
         self,
@@ -323,23 +327,55 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
             height=image_height,
         )


+class Phi3VProfilingInfo(Phi3VProcessingMixin, BaseProfilingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
     def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        target_width, target_height = self._get_image_size_with_most_features()
+
         max_image_tokens = self._get_num_image_tokens(
-            image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
-            image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+            image_width=target_width,
+            image_height=target_height,
         )

         return {"image": max_image_tokens}

-    def _get_hf_processor(
-        self,
-        *,
-        num_crops: Optional[int] = None,
-    ) -> ProcessorMixin:
-        if num_crops is not None:
-            return self.ctx.get_hf_processor(num_crops=num_crops)
-        return self.ctx.get_hf_processor()
+    def _get_image_size_with_most_features(self) -> ImageSize:
+        # Result in the max possible feature size (h:w = 16:1)
+        return ImageSize(height=8000, width=50)
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        num_images = mm_counts.get("image", 0)
+
+        target_width, target_height = self._get_image_size_with_most_features()
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=target_width,
+                                   height=target_height,
+                                   num_images=num_images)
+        }
+
+        hf_processor = self._get_hf_processor()
+        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=mm_data,
)
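For concreteness, a small illustration of what the dummy inputs built above look like; the `<|image_N|>` placeholder strings are an assumption about Phi-3-vision's HF processor (`img_tokens`), not something this diff states.

# Illustrative only: rough shape of get_dummy_processor_inputs for
# mm_counts == {"image": 2}, assuming img_tokens == ["<|image_1|>", "<|image_2|>", ...].
image_tokens = ["<|image_1|>", "<|image_2|>"]
num_images = 2
prompt_text = "".join(image_tokens[:num_images])
assert prompt_text == "<|image_1|><|image_2|>"
# mm_data["image"] would then hold two blank 50x8000 RGB images
# (width=50, height=8000 from _get_image_size_with_most_features above).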
class Phi3VMultiModalProcessor(Phi3VProcessingMixin, BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Phi3VProfilingInfo(self.ctx)
def _call_hf_processor( def _call_hf_processor(
self, self,
@ -377,10 +413,10 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _get_prompt_replacements( def _get_prompt_replacements(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore image_tokens: list[str] = hf_processor.img_tokens # type: ignore
tokenizer = self._get_tokenizer() tokenizer = self._get_tokenizer()
@ -442,28 +478,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
return token_ids, text, placeholders return token_ids, text, placeholders
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
)
def apply( def apply(
self, self,
prompt_text: str, prompt_text: str,

View File

@ -20,8 +20,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
Union) TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -40,8 +40,9 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs, MultiModalDataItems, ProcessingMixin,
PromptReplacement) PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
@ -79,17 +80,10 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
return feat_lengths, output_lengths return feat_lengths, output_lengths
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): class Qwen2AudioProcessingMixin(ProcessingMixin):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def _get_hf_config(self):
return {"audio": None} return self.ctx.get_hf_config(Qwen2AudioConfig)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
return {"audio": max_output_lengths}
def _get_hf_processor( def _get_hf_processor(
self, self,
@ -99,8 +93,57 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
) -> Qwen2AudioProcessor: ) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor) return self.ctx.get_hf_processor(Qwen2AudioProcessor)
def _get_feature_extractor(self) -> WhisperFeatureExtractor: def _get_feature_extractor(
return self._get_hf_processor().feature_extractor # type: ignore self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
class Qwen2AudioProfilingInfo(Qwen2AudioProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
hf_config = self._get_hf_config()
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
return {"audio": max_output_lengths}
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)
mm_data = {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
return ProcessorInputs(
prompt_text="<|AUDIO|>" * num_audios,
mm_data=mm_data,
)
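As a quick sanity check of the dummy audio length used above, assuming Whisper's usual feature-extractor defaults of 30-second chunks sampled at 16 kHz (an assumption, not taken from this diff):

# Minimal sketch of the profiling audio built above, under assumed
# WhisperFeatureExtractor defaults (chunk_length=30 s, sampling_rate=16000 Hz).
import numpy as np

chunk_length = 30              # seconds (assumed default)
sampling_rate = 16_000         # Hz (assumed default)
audio_len = chunk_length * sampling_rate          # 480_000 zero samples per dummy audio
dummy_audios = [np.zeros((audio_len, ))] * 2      # mirrors _get_dummy_audios(length=audio_len, num_audios=2)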
class Qwen2AudioMultiModalProcessor(Qwen2AudioProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Qwen2AudioProfilingInfo(self.ctx)
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
@ -110,7 +153,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
self, self,
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, Any],
) -> BatchFeature: ) -> BatchFeature:
mm_data = dict(mm_data) mm_data = dict(mm_data)
audios = mm_data.pop("audios", []) audios = mm_data.pop("audios", [])
@ -118,7 +161,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
if audios: if audios:
mm_data["audios"] = audios mm_data["audios"] = audios
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor(**mm_kwargs)
mm_kwargs = dict( mm_kwargs = dict(
**mm_kwargs, **mm_kwargs,
sampling_rate=feature_extractor.sampling_rate, sampling_rate=feature_extractor.sampling_rate,
@ -151,7 +194,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) hf_config = self._get_hf_config()
placeholder = hf_config.audio_token_index placeholder = hf_config.audio_token_index
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask") feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
@ -191,27 +234,6 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
# tokens than the number of audio items) # tokens than the number of audio items)
return True return True
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)
mm_data = {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
return ProcessorInputs(
prompt_text="<|AUDIO|>" * num_audios,
mm_data=mm_data,
)
@MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor) @MULTIMODAL_REGISTRY.register_processor(Qwen2AudioMultiModalProcessor)
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,

View File

@ -59,8 +59,9 @@ from vllm.multimodal.inputs import (ImageItem, ModalityData,
from vllm.multimodal.parse import (ImageSize, ModalityDataItems, from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs, MultiModalDataItems, ProcessingMixin,
PromptReplacement) PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.platforms import _Backend from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
@ -708,10 +709,44 @@ class Qwen2MultiModalDataParser(MultiModalDataParser):
return super()._parse_video_data(data) return super()._parse_video_data(data)
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): class Qwen2VLProcessingMixin(ProcessingMixin):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def _get_hf_config(self):
return {"image": None, "video": None} return self.ctx.get_hf_config(Qwen2VLConfig)
def _get_hf_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
) -> Qwen2VLProcessor:
hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
if min_pixels:
image_processor.min_pixels = min_pixels
if max_pixels:
image_processor.max_pixels = max_pixels
if max_pixels or min_pixels:
image_processor.size = {
"min_pixels": image_processor.min_pixels,
"max_pixels": image_processor.max_pixels,
}
return hf_processor
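These `min_pixels`/`max_pixels` overrides arrive through `hf_processor_mm_kwargs` (note the `_get_hf_processor(**hf_processor_mm_kwargs)` call further down). A hedged sketch of how a caller might set them; treating `mm_processor_kwargs` as the entry point is an assumption about vLLM's public API rather than something shown in this diff:

# Hedged sketch: bounding Qwen2-VL's image resolution via processor kwargs.
# The mm_processor_kwargs entry point and the pixel budgets are illustrative.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    mm_processor_kwargs={
        "min_pixels": 256 * 28 * 28,     # lower bound on image area
        "max_pixels": 1280 * 28 * 28,    # upper bound on image area
    },
)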
def _get_image_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
):
hf_processor = self._get_hf_processor(min_pixels=min_pixels,
max_pixels=max_pixels)
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
def _get_vision_info( def _get_vision_info(
self, self,
@ -721,14 +756,13 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
num_frames: int = 1, num_frames: int = 1,
do_resize: bool = True, do_resize: bool = True,
) -> tuple[ImageSize, int]: ) -> tuple[ImageSize, int]:
hf_config = self.ctx.get_hf_config(Qwen2VLConfig) hf_config = self._get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size temporal_patch_size = vision_config.temporal_patch_size
hf_processor = self._get_hf_processor() image_processor = self._get_image_processor()
image_processor = self._get_image_processor(hf_processor)
if do_resize: if do_resize:
resized_height, resized_width = smart_resize( resized_height, resized_width = smart_resize(
@ -753,7 +787,45 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
return preprocessed_size, num_vision_tokens return preprocessed_size, num_vision_tokens
def _get_dummy_image_size(self) -> ImageSize: def _get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
_, num_image_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
)
return num_image_tokens
def _get_num_video_tokens(
self,
*,
image_width: int,
image_height: int,
num_frames: int,
) -> int:
_, num_video_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
num_frames=num_frames,
)
return num_video_tokens
class Qwen2VLProfilingInfo(Qwen2VLProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {
"image": self._get_max_image_tokens(),
"video": self._get_max_video_tokens(seq_len),
}
def _get_image_size_with_most_features(self) -> ImageSize:
max_image_size, _ = self._get_vision_info( max_image_size, _ = self._get_vision_info(
image_width=9999999, image_width=9999999,
image_height=9999999, image_height=9999999,
@ -761,27 +833,27 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
return max_image_size return max_image_size
def _get_max_image_tokens(self) -> int: def _get_max_image_tokens(self) -> int:
_, max_image_tokens = self._get_vision_info( target_width, target_height = self._get_image_size_with_most_features()
image_width=9999999,
image_height=9999999,
)
return max_image_tokens
def _get_max_video_tokens(self, num_frames: int) -> int: return self._get_num_image_tokens(
_, max_video_tokens = self._get_vision_info( image_width=target_width,
image_width=9999999, image_height=target_height,
image_height=9999999,
num_frames=num_frames,
) )
return max_video_tokens
def _get_max_video_frames(self, max_tokens: int) -> int: def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
num_frames = 0 num_frames = 0
while True: while True:
next_num_frames = num_frames + 1 next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
)
if self._get_max_video_tokens(next_num_frames) > max_tokens: if next_max_tokens > max_tokens:
break break
num_frames = next_num_frames num_frames = next_num_frames
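The loop above simply walks the frame count upward until the estimated per-video token cost would exceed the remaining budget. A self-contained sketch of the same idea, with an invented `tokens_for_frames` standing in for `_get_num_video_tokens` (names and numbers are illustrative only):

# Illustrative only: largest frame count whose token cost fits a budget,
# mirroring the incremental search in _get_max_video_frames.
def tokens_for_frames(num_frames: int, tokens_per_frame: int = 330) -> int:
    return num_frames * tokens_per_frame   # stand-in for _get_num_video_tokens

def max_frames_within_budget(max_tokens: int) -> int:
    num_frames = 0
    while tokens_for_frames(num_frames + 1) <= max_tokens:
        num_frames += 1
    return num_frames

assert max_frames_within_budget(1000) == 3   # 3 * 330 = 990 fits, 4 * 330 = 1320 does not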
@ -797,56 +869,73 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
max_total_frames = self._get_max_video_frames(seq_len - max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens) max_image_tokens)
return max(max_total_frames // max(max_videos, 1), 1) num_frames = max(max_total_frames // max(max_videos, 1), 1)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
max_image_tokens = self._get_max_image_tokens() if num_frames > 1 and num_frames % 2 == 1:
num_frames += 1
num_frames = self._get_dummy_num_frames(seq_len) return num_frames
max_video_tokens = self._get_max_video_tokens(num_frames)
return { def _get_max_video_tokens(self, seq_len: int) -> int:
"image": max_image_tokens, target_width, target_height = self._get_image_size_with_most_features()
"video": max_video_tokens,
return self._get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
)
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token
video_token: str = hf_processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
} }
return ProcessorInputs(
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)
class Qwen2VLMultiModalProcessor(Qwen2VLProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return Qwen2VLProfilingInfo(self.ctx)
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser() return Qwen2MultiModalDataParser()
def _get_image_processor(self, hf_processor: Qwen2VLProcessor):
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
def _get_hf_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
) -> Qwen2VLProcessor:
hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
image_processor = self._get_image_processor(hf_processor)
if min_pixels:
image_processor.min_pixels = min_pixels
if max_pixels:
image_processor.max_pixels = max_pixels
if max_pixels or min_pixels:
image_processor.size = {
"min_pixels": image_processor.min_pixels,
"max_pixels": image_processor.max_pixels,
}
return hf_processor
def _get_prompt_replacements( def _get_prompt_replacements(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self._get_image_processor(hf_processor) image_processor = self._get_image_processor(**hf_processor_mm_kwargs)
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered # image_token and video_token registered
@ -901,38 +990,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
video_grid_thw=MultiModalFieldConfig.batched("video"), video_grid_thw=MultiModalFieldConfig.batched("video"),
) )
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token
video_token: str = hf_processor.video_token
target_width, target_height = self._get_dummy_image_size()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
}
return ProcessorInputs(
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor) @MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,

View File

@ -3,8 +3,8 @@
import math import math
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
TypedDict, Union) Tuple, TypedDict, Union)
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
@ -26,8 +26,9 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs, MultiModalDataItems, ProcessingMixin,
PromptReplacement) PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
@ -55,7 +56,30 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
UltravoxAudioEmbeddingInputs] UltravoxAudioEmbeddingInputs]
class UltravoxMultiModalProcessor(BaseMultiModalProcessor): class UltravoxProcessingMixin(ProcessingMixin):
def _get_hf_processor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> ProcessorMixin:
return self.ctx.get_hf_processor()
def _get_feature_extractor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor(sampling_rate=sampling_rate)
audio_processor = hf_processor.audio_processor # type: ignore
feature_extractor = audio_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
class UltravoxProfilingInfo(UltravoxProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None} return {"audio": None}
@ -67,17 +91,33 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
return {"audio": max_audio_tokens} return {"audio": max_audio_tokens}
def _get_hf_processor( def get_dummy_processor_inputs(
self, self,
*, seq_len: int,
# Ignored in initialization mm_counts: Mapping[str, int],
sampling_rate: Optional[int] = None, ) -> ProcessorInputs:
) -> ProcessorMixin: feature_extractor = self._get_feature_extractor()
return self.ctx.get_hf_processor()
def _get_feature_extractor(self) -> WhisperFeatureExtractor: sampling_rate = feature_extractor.sampling_rate
hf_processor = self._get_hf_processor() audio_len = feature_extractor.chunk_length * sampling_rate
return hf_processor.audio_processor.feature_extractor # type: ignore num_audios = mm_counts.get("audio", 0)
mm_data = {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
return ProcessorInputs(
prompt_text="<|audio|>" * num_audios,
mm_data=mm_data,
)
class UltravoxMultiModalProcessor(UltravoxProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return UltravoxProfilingInfo(self.ctx)
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
@ -155,10 +195,10 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
def _get_prompt_replacements( def _get_prompt_replacements(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor(**hf_processor_mm_kwargs)
placeholder = hf_processor.audio_token_replacement # type: ignore placeholder = hf_processor.audio_token_replacement # type: ignore
def get_replacement_ultravox(item_idx: int): def get_replacement_ultravox(item_idx: int):
@ -173,27 +213,6 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
) )
] ]
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self._get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)
mm_data = {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
return ProcessorInputs(
prompt_text="<|audio|>" * num_audios,
mm_data=mm_data,
)
class StackAudioFrames(nn.Module): class StackAudioFrames(nn.Module):
""" """

View File

@ -1,12 +1,8 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Final, Generic, Optional, Protocol, TypeVar from typing import Final, Generic, Protocol, TypeVar
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ProcessingCache)
_C = TypeVar("_C", bound=PretrainedConfig) _C = TypeVar("_C", bound=PretrainedConfig)
@ -43,12 +39,18 @@ class VisionEncoderInfo(ABC, Generic[_C]):
raise NotImplementedError raise NotImplementedError
def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo: class VisionLanguageConfig(Protocol):
vision_config: Final[PretrainedConfig]
def get_vision_encoder_info(
hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
# Avoid circular imports # Avoid circular imports
from .clip import CLIPEncoderInfo, CLIPVisionConfig from .clip import CLIPEncoderInfo, CLIPVisionConfig
from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
from .siglip import SiglipEncoderInfo, SiglipVisionConfig from .siglip import SiglipEncoderInfo, SiglipVisionConfig
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig): if isinstance(vision_config, CLIPVisionConfig):
return CLIPEncoderInfo(vision_config) return CLIPEncoderInfo(vision_config)
if isinstance(vision_config, PixtralVisionConfig): if isinstance(vision_config, PixtralVisionConfig):
@ -58,26 +60,3 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
msg = f"Unsupported vision config: {type(vision_config)}" msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg) raise NotImplementedError(msg)
class VisionLanguageConfig(Protocol):
vision_config: Final[PretrainedConfig]
class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
def __init__(self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks)
vision_config = self._get_hf_config().vision_config
self._vision_encoder_info = vision_encoder_info(vision_config)
@abstractmethod
def _get_hf_config(self) -> VisionLanguageConfig:
raise NotImplementedError

View File

@ -8,11 +8,10 @@ from functools import lru_cache
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
import numpy as np import numpy as np
import numpy.typing as npt
import torch import torch
from blake3 import blake3 from blake3 import blake3
from PIL import Image from PIL import Image
from transformers import BatchFeature, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from vllm.inputs import DummyData, InputProcessingContext from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
@ -24,6 +23,7 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser from .parse import MultiModalDataItems, MultiModalDataParser
from .profiling import BaseProfilingInfo
logger = init_logger(__name__) logger = init_logger(__name__)
@ -466,14 +466,6 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
@dataclass
class ProcessorInputs:
"""Keyword arguments to :meth:`BaseMultiModalProcessor`."""
prompt_text: str
mm_data: MultiModalDataDict
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
class ProcessingCache: class ProcessingCache:
def __init__(self, capacity: int) -> None: def __init__(self, capacity: int) -> None:
@ -585,9 +577,33 @@ class ProcessingCache:
self._cache.put(cache_key, output_kwargs) self._cache.put(cache_key, output_kwargs)
class BaseMultiModalProcessor(ABC): class ProcessingMixin:
"""
Contains helper functions to perform processing.
Not to be confused with :class:`transformers.ProcessorMixin`.
"""
ctx: InputProcessingContext
def _get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer
def _get_hf_config(self) -> PretrainedConfig:
return self.ctx.get_hf_config()
def _get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
"""
Subclasses can override this method to handle
specific kwargs from model config or user inputs.
"""
return self.ctx.get_hf_processor(**kwargs)
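In the models touched by this commit, each architecture builds a small mixin on top of this class that pins down its HF config/processor and accepts model-specific processor kwargs (see the Phi3V and Qwen2-VL mixins above). A hedged sketch of that pattern; the `MyModel*` names are invented:

# Hypothetical model-specific mixin following the pattern in this commit;
# only the shape is taken from the diff, the names are made up.
from typing import Optional

from transformers import ProcessorMixin

from vllm.multimodal.processing import ProcessingMixin


class MyModelProcessingMixin(ProcessingMixin):

    def _get_hf_config(self):
        # Narrow the config type here if needed, e.g.
        # return self.ctx.get_hf_config(MyModelConfig)
        return self.ctx.get_hf_config()

    def _get_hf_processor(
        self,
        *,
        num_crops: Optional[int] = None,   # model-specific kwarg, as in Phi3V
    ) -> ProcessorMixin:
        if num_crops is not None:
            return self.ctx.get_hf_processor(num_crops=num_crops)
        return self.ctx.get_hf_processor()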
class BaseMultiModalProcessor(ProcessingMixin, ABC):
""" """
Abstract base class to process multi-modal inputs to be used in vLLM. Abstract base class to process multi-modal inputs to be used in vLLM.
Not to be confused with :class:`transformers.ProcessorMixin`.
""" """
def __init__(self, def __init__(self,
@ -601,6 +617,9 @@ class BaseMultiModalProcessor(ABC):
self.cache = cache self.cache = cache
self.enable_sanity_checks = enable_sanity_checks self.enable_sanity_checks = enable_sanity_checks
self.data_parser = self._get_data_parser()
self.profiling_info = self._get_profiling_info()
def __call__( def __call__(
self, self,
prompt: str, prompt: str,
@ -609,32 +628,9 @@ class BaseMultiModalProcessor(ABC):
) -> MultiModalInputsV2: ) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs) return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
""" """
Construct a data parser to preprocess multi-modal data items Construct a parser to preprocess multi-modal data items
before passing them to :meth:`_get_hf_mm_data`. before passing them to :meth:`_get_hf_mm_data`.
You can support additional modalities by creating a subclass You can support additional modalities by creating a subclass
@ -642,15 +638,12 @@ class BaseMultiModalProcessor(ABC):
""" """
return MultiModalDataParser() return MultiModalDataParser()
def _get_hf_processor(self) -> ProcessorMixin: def _get_profiling_info(self) -> BaseProfilingInfo:
""" """
Subclasses can add keyword arguments to this method to accept Get the profiling information to find the worst-case memory usage of
additional kwargs from model config or user inputs. the model.
""" """
return self.ctx.get_hf_processor() raise NotImplementedError
def _get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer
def _to_mm_items( def _to_mm_items(
self, self,
@ -660,8 +653,7 @@ class BaseMultiModalProcessor(ABC):
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems` Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
before passing them to :meth:`_get_hf_mm_data`. before passing them to :meth:`_get_hf_mm_data`.
""" """
parser = self._get_data_parser() mm_items = self.data_parser.parse_mm_data(mm_data)
mm_items = parser.parse_mm_data(mm_data)
mm_limits = self.ctx.get_mm_config().limit_per_prompt mm_limits = self.ctx.get_mm_config().limit_per_prompt
for modality, items in mm_items.items(): for modality, items in mm_items.items():
@ -799,7 +791,7 @@ class BaseMultiModalProcessor(ABC):
# Some HF processors (e.g. Qwen2-VL) expect corresponding # Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text # multi-modal tokens to be in the prompt text
dummy_inputs = self._get_dummy_processor_inputs( dummy_inputs = self.profiling_info.get_dummy_processor_inputs(
self.ctx.model_config.max_model_len, self.ctx.model_config.max_model_len,
mm_missing_counts, mm_missing_counts,
) )
@ -1133,73 +1125,14 @@ class BaseMultiModalProcessor(ABC):
mm_placeholders=mm_placeholder_ranges, mm_placeholders=mm_placeholder_ranges,
) )
def _get_dummy_audios(
self,
*,
length: int,
num_audios: int,
) -> list[npt.NDArray]:
audio = np.zeros((length, ))
return [audio] * num_audios
def _get_dummy_images(
self,
*,
width: int,
height: int,
num_images: int,
) -> list[Image.Image]:
image = Image.new("RGB", (width, height), color=0)
return [image] * num_images
def _get_dummy_videos(
self,
*,
width: int,
height: int,
num_frames: int,
num_videos: int,
) -> list[npt.NDArray]:
video = np.zeros((num_frames, width, height, 3))
return [video] * num_videos
@abstractmethod
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
"""
Build the multi-modal portion of the input which, after processing,
results in `mm_max_tokens` in :meth:`get_dummy_data`.
"""
raise NotImplementedError
def _get_and_validate_dummy_mm_counts(self) -> Mapping[str, int]:
mm_limit_per_prompt = self.ctx.get_mm_config().limit_per_prompt
supported_mm_limits = self.get_supported_mm_limits()
mm_limits = {
modality: mm_limit_per_prompt.get(modality, 1)
for modality in supported_mm_limits
}
for modality, supported_limit in supported_mm_limits.items():
limit = mm_limits[modality]
if supported_limit is not None and supported_limit < limit:
raise ValueError(
f"You set {modality}={limit} (or defaulted to 1) in "
f"`--limit-mm-per-prompt`, but this model only supports "
f"at most {supported_limit} {modality} items.")
return mm_limits
def _get_dummy_mm_inputs( def _get_dummy_mm_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> MultiModalInputsV2: ) -> MultiModalInputsV2:
processor_inputs = self._get_dummy_processor_inputs(seq_len, mm_counts) profiling = self.profiling_info
processor_inputs = profiling.get_dummy_processor_inputs(
seq_len, mm_counts)
return self.apply( return self.apply(
prompt_text=processor_inputs.prompt_text, prompt_text=processor_inputs.prompt_text,
@ -1211,8 +1144,9 @@ class BaseMultiModalProcessor(ABC):
# Avoid circular import # Avoid circular import
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
mm_counts = self._get_and_validate_dummy_mm_counts() profiling = self.profiling_info
mm_max_tokens_per_item = self.get_mm_max_tokens_per_item(seq_len) mm_counts = profiling.get_mm_limits()
mm_max_tokens_per_item = profiling.get_mm_max_tokens_per_item(seq_len)
if mm_counts.keys() != mm_max_tokens_per_item.keys(): if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError( raise AssertionError(
"The keys returned by `get_supported_mm_limits`" "The keys returned by `get_supported_mm_limits`"

View File

@ -0,0 +1,121 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm.inputs import InputProcessingContext
from vllm.logger import init_logger
from .inputs import MultiModalDataDict
logger = init_logger(__name__)
@dataclass
class ProcessorInputs:
"""Keyword arguments to :meth:`BaseMultiModalProcessor`."""
prompt_text: str
mm_data: MultiModalDataDict
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
class BaseProfilingInfo(ABC):
"""
Abstract base class that provides the information necessary to profile
multi-modal models.
"""
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__()
self.ctx = ctx
@abstractmethod
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise NotImplementedError
@abstractmethod
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
"""
Build the multi-modal portion of the input which, after processing,
results in `mm_max_tokens` in :meth:`get_mm_max_tokens_per_item`.
"""
raise NotImplementedError
def _get_dummy_audios(
self,
*,
length: int,
num_audios: int,
) -> list[npt.NDArray]:
audio = np.zeros((length, ))
return [audio] * num_audios
def _get_dummy_images(
self,
*,
width: int,
height: int,
num_images: int,
) -> list[Image.Image]:
image = Image.new("RGB", (width, height), color=0)
return [image] * num_images
def _get_dummy_videos(
self,
*,
width: int,
height: int,
num_frames: int,
num_videos: int,
) -> list[npt.NDArray]:
video = np.zeros((num_frames, width, height, 3))
return [video] * num_videos
def get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.ctx.get_mm_config()
mm_limit_per_prompt = mm_config.limit_per_prompt
supported_mm_limits = self.get_supported_mm_limits()
mm_limits = {
modality: mm_limit_per_prompt.get(modality, 1)
for modality in supported_mm_limits
}
for modality, supported_limit in supported_mm_limits.items():
limit = mm_limits[modality]
if supported_limit is not None and supported_limit < limit:
raise ValueError(
f"You set {modality}={limit} (or defaulted to 1) in "
f"`--limit-mm-per-prompt`, but this model only supports "
f"at most {supported_limit} {modality} items.")
return mm_limits
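Putting the new pieces together: a profiling class implements the three abstract methods above (optionally sharing helpers with the processor via a `ProcessingMixin` subclass, as the models in this commit do), and the processor returns it from `_get_profiling_info`. A hedged sketch with invented `MyModel*` names and illustrative numbers:

# Hedged sketch of the new profiling wiring; the shape mirrors the Qwen2-Audio
# and Ultravox classes in this commit, but all names and numbers are illustrative.
from collections.abc import Mapping
from typing import Optional

from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs


class MyModelProfilingInfo(BaseProfilingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": None}   # no hard cap on images per prompt

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        return {"image": 576}    # illustrative worst-case tokens per image

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)
        return ProcessorInputs(
            prompt_text="<image>" * num_images,   # placeholder string is illustrative
            mm_data={
                "image":
                self._get_dummy_images(width=336,
                                       height=336,
                                       num_images=num_images)
            },
        )


class MyModelMultiModalProcessor(BaseMultiModalProcessor):
    # other abstract hooks (e.g. _get_prompt_replacements) omitted for brevity

    def _get_profiling_info(self) -> BaseProfilingInfo:
        return MyModelProfilingInfo(self.ctx)

During profiling, `get_mm_limits` above then cross-checks the supported limits against `--limit-mm-per-prompt` and raises the "this model only supports" ValueError if the user requests more items than the model allows.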

View File

@ -224,7 +224,7 @@ class MultiModalRegistry:
tokenizer = cached_get_tokenizer(model_config.tokenizer) tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = self.create_processor(model_config, tokenizer) processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
return processor.get_mm_max_tokens_per_item(seq_len) return processor.profiling_info.get_mm_max_tokens_per_item(seq_len)
return { return {
key: plugin.get_max_multimodal_tokens(model_config) key: plugin.get_max_multimodal_tokens(model_config)