[Bugfix] Check dimensions of multimodal embeddings in V1 (#15816)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-04-01 00:01:35 +08:00 committed by GitHub
parent e5ef4fa99a
commit 09e974d483
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 98 additions and 37 deletions

View File

@ -68,7 +68,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompts = [f"Question: {question} Answer:" for question in questions]
engine_args = EngineArgs(
model="Salesforce/blip2-opt-2.7b",
model="Salesforce/blip2-opt-6.7b",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
@ -128,7 +128,8 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
engine_args = EngineArgs(
model="microsoft/Florence-2-large",
tokenizer="facebook/bart-large",
max_num_seqs=8,
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
dtype="bfloat16",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
@ -511,7 +512,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=16,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
@ -700,7 +701,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
# NOTE: Need L40 (or equivalent) to avoid OOM
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_model_len=6144,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

View File

@ -217,7 +217,7 @@ EMBEDDING_MODELS = { # type: ignore[var-annotated]
MULTIMODAL_MODELS = {
# [Decoder-only]
"Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
"Salesforce/blip2-opt-6.7b": PPTestSettings.fast(),
"facebook/chameleon-7b": PPTestSettings.fast(),
"adept/fuyu-8b": PPTestSettings.fast(),
"THUDM/glm-4v-9b": PPTestSettings.fast(),

View File

@ -34,8 +34,6 @@ REQUIRES_V0_MODELS = [
# V1 Test: no way to fall back for head_dim = 80
# https://github.com/vllm-project/vllm/issues/14524
"qwen_vl",
"h2ovl",
"blip2",
# V1 Test: not enough KV cache space in C1.
"fuyu",
]
@ -161,7 +159,8 @@ VLM_TEST_SETTINGS = {
marks=[large_gpu_mark(min_gb=64)],
),
"blip2": VLMTestInfo(
models=["Salesforce/blip2-opt-2.7b"],
# TODO: Change back to 2.7b once head_dim = 80 is supported
models=["Salesforce/blip2-opt-6.7b"],
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:",
img_idx_to_prompt=lambda idx: "",
@ -248,7 +247,8 @@ VLM_TEST_SETTINGS = {
"h2ovl": VLMTestInfo(
models = [
"h2oai/h2ovl-mississippi-800m",
"h2oai/h2ovl-mississippi-2b",
# TODO: Re-enable once head_dim = 80 is supported
# "h2oai/h2ovl-mississippi-2b",
],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>", # noqa: E501

View File

@ -259,7 +259,8 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
_MULTIMODAL_EXAMPLE_MODELS = {
# [Decoder-only]
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b", # noqa: E501
extras={"6b": "Salesforce/blip2-opt-6.7b"}), # noqa: E501
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
extras={"fork": "Isotr0py/deepseek-vl2-tiny"}, # noqa: E501

View File

@ -875,7 +875,8 @@ class Florence2MultiModalProcessor(
Florence2MultiModalProcessor,
info=Florence2ProcessingInfo,
dummy_inputs=Florence2DummyInputsBuilder)
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -39,7 +39,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
@ -66,10 +65,13 @@ class FuyuImagePatchInputs(TypedDict):
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
@ -322,16 +324,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
image_patches = kwargs.pop("image_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
if image_patches is not None:
if not isinstance(image_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image patches. "
f"Got type: {type(image_patches)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
image_patches_flat = flatten_bn(image_patches)
embed_is_patch = flatten_bn(embed_is_patch)
return FuyuImagePatchInputs(
type="image_patches",
@ -351,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
assert self.vision_embed_tokens is not None
vision_embeddings_flat, _ = self.vision_embed_tokens(
image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(
@ -358,13 +363,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
#return vision_embeddings
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
vision_embeddings,
image_input["embed_is_patch"],
))
image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,

View File

@ -613,7 +613,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def _process_image_input(
self,
image_input: Gemma3ImageInputs,
) -> tuple[torch.Tensor, ...]:
) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
@ -625,7 +625,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
)
image_embeds = self.multi_modal_projector(image_features)
return image_embeds.split(num_patches.tolist())
return [
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
]
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

View File

@ -733,7 +733,10 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_attention_mask=pixel_attention_mask,
)
def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
def _process_image_input(
self,
image_input: ImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return image_input["data"]
@ -741,7 +744,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
image_features = self.model.connector(image_features)
num_patches = image_input["num_patches"]
return image_features.split(num_patches.tolist())
return [
e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
]
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:

View File

@ -406,20 +406,21 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
h, w)
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, stacked_pixels)
return stacked_embeddings.view(b, num_frames,
*stacked_embeddings.shape[1:])
embeds = stacked_embeddings.view(b, num_frames,
*stacked_embeddings.shape[1:])
elif is_list_of(video_pixels, torch.Tensor):
frames_per_videos = [v.shape[0] for v in video_pixels]
stacked_pixels = torch.cat(video_pixels, dim=0)
stacked_embeddings = self._video_pixels_to_features(
self.vision_tower, stacked_pixels)
return torch.split(stacked_embeddings, frames_per_videos, dim=0)
embeds = torch.split(stacked_embeddings, frames_per_videos, dim=0)
else:
raise ValueError(
f"Unsupported type of video input {type(video_pixels)}")
return [e.flatten(0, 1) for e in embeds]
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
video_input = self._parse_and_validate_video_input(**kwargs)

View File

@ -919,8 +919,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
image_features_flat = self.get_vision_hidden_states(image_input)
# Reconstruct the batch dimension
return image_features_flat.split(image_input["num_slices"].tolist())
num_slices = image_input["num_slices"]
return [
e.flatten(0, 1)
for e in image_features_flat.split(num_slices.tolist())
]
def _process_multimodal_inputs(self, modalities: dict):
# The result multimodal_embeddings is tuple of tensors, with each

View File

@ -204,7 +204,7 @@ def scatter_patch_features(
(e_is_patch.shape[0], patches_one.shape[-1]),
fill_value=torch.nan,
)
embed_one[e_is_patch] = patches_one.flatten(0, -2)
embed_one[e_is_patch] = patches_one
return embed_one
return tuple(

View File

@ -41,6 +41,8 @@ from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import sanity_check_mm_encoder_outputs
if TYPE_CHECKING:
import xgrammar as xgr
@ -867,6 +869,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=len(grouped_mm_inputs),
)
for output in curr_group_outputs:
encoder_outputs.append(output)
@ -1490,12 +1497,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
assert len(dummy_encoder_outputs) == max_num_mm_items, (
"Expected dimension 0 of encoder outputs to match the number "
f"of multimodal data items: {max_num_mm_items}, got "
f"{len(dummy_encoder_outputs)=} instead. This is most likely "
"due to the 'get_multimodal_embeddings' method of the model "
"not implemented correctly.")
sanity_check_mm_encoder_outputs(
dummy_encoder_outputs,
expected_num_items=max_num_mm_items,
)
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

View File

@ -37,6 +37,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from .utils import sanity_check_mm_encoder_outputs
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@ -512,6 +514,11 @@ class TPUModelRunner:
curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=len(grouped_mm_inputs),
)
for output in curr_group_outputs:
encoder_outputs.append(output)

29
vllm/v1/worker/utils.py Normal file
View File

@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
import torch
def sanity_check_mm_encoder_outputs(
mm_embeddings: object,
expected_num_items: int,
) -> None:
"""
Perform sanity checks for the result of
:meth:`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`.
"""
assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
"Expected multimodal embeddings to be a list/tuple of 2D tensors, "
f"or a single 3D tensor, but got {type(mm_embeddings)} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
assert len(mm_embeddings) == expected_num_items, (
"Expected number of multimodal embeddings to match number of "
f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")
assert all(e.ndim == 2 for e in mm_embeddings), (
"Expected multimodal embeddings to be a sequence of 2D tensors, "
f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
"instead. This is most likely due to incorrect implementation "
"of the model's `get_multimodal_embeddings` method.")