[Bugfix] Check dimensions of multimodal embeddings in V1 (#15816)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit 09e974d483 (parent e5ef4fa99a)
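
Summary: V1's GPU and TPU model runners now validate the output of each model's `get_multimodal_embeddings` with a new `sanity_check_mm_encoder_outputs` helper (added in vllm/v1/worker/utils.py), which asserts that the number of returned embeddings matches the number of multimodal items and that each embedding is a 2D tensor. Fuyu, Gemma3, Idefics3, LLaVA-NeXT-Video and MiniCPM-V are updated to return one flattened 2D tensor per item, Florence-2 is marked `SupportsV0Only`, and the BLIP-2 examples/tests switch from opt-2.7b to opt-6.7b because head_dim = 80 is not yet supported in V1. A minimal illustration of the checked contract (the sizes below are made up, not taken from the diff):

import torch

# One 2D (num_tokens, hidden_size) tensor per multimodal item.
hidden_size = 64
embeddings = [
    torch.randn(16, hidden_size),  # e.g. first image -> 16 embedding tokens
    torch.randn(24, hidden_size),  # e.g. second image -> 24 embedding tokens
]
assert len(embeddings) == 2                   # one entry per input item
assert all(e.ndim == 2 for e in embeddings)   # every entry is 2D
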
@@ -68,7 +68,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
     # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
     prompts = [f"Question: {question} Answer:" for question in questions]
     engine_args = EngineArgs(
-        model="Salesforce/blip2-opt-2.7b",
+        model="Salesforce/blip2-opt-6.7b",
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )

@@ -128,7 +128,8 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
     engine_args = EngineArgs(
         model="microsoft/Florence-2-large",
         tokenizer="facebook/bart-large",
-        max_num_seqs=8,
+        max_model_len=4096,
+        max_num_seqs=2,
         trust_remote_code=True,
         dtype="bfloat16",
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
@@ -511,7 +512,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_name,
         max_model_len=4096,
-        max_num_seqs=16,
+        max_num_seqs=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )

@@ -700,7 +701,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
     # NOTE: Need L40 (or equivalent) to avoid OOM
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=8192,
+        max_model_len=6144,
         max_num_seqs=2,
         disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
     )
@@ -217,7 +217,7 @@ EMBEDDING_MODELS = { # type: ignore[var-annotated]

 MULTIMODAL_MODELS = {
     # [Decoder-only]
-    "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(),
+    "Salesforce/blip2-opt-6.7b": PPTestSettings.fast(),
     "facebook/chameleon-7b": PPTestSettings.fast(),
     "adept/fuyu-8b": PPTestSettings.fast(),
     "THUDM/glm-4v-9b": PPTestSettings.fast(),
@@ -34,8 +34,6 @@ REQUIRES_V0_MODELS = [
     # V1 Test: no way to fall back for head_dim = 80
     # https://github.com/vllm-project/vllm/issues/14524
     "qwen_vl",
-    "h2ovl",
-    "blip2",
     # V1 Test: not enough KV cache space in C1.
     "fuyu",
 ]
@@ -161,7 +159,8 @@ VLM_TEST_SETTINGS = {
         marks=[large_gpu_mark(min_gb=64)],
     ),
     "blip2": VLMTestInfo(
-        models=["Salesforce/blip2-opt-2.7b"],
+        # TODO: Change back to 2.7b once head_dim = 80 is supported
+        models=["Salesforce/blip2-opt-6.7b"],
         test_type=VLMTestType.IMAGE,
         prompt_formatter=lambda img_prompt: f"Question: {img_prompt} Answer:",
         img_idx_to_prompt=lambda idx: "",
@@ -248,7 +247,8 @@ VLM_TEST_SETTINGS = {
     "h2ovl": VLMTestInfo(
         models = [
            "h2oai/h2ovl-mississippi-800m",
-           "h2oai/h2ovl-mississippi-2b",
+           # TODO: Re-enable once head_dim = 80 is supported
+           # "h2oai/h2ovl-mississippi-2b",
         ],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         prompt_formatter=lambda img_prompt: f"<|prompt|>{img_prompt}<|end|><|answer|>",  # noqa: E501
@@ -259,7 +259,8 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
 _MULTIMODAL_EXAMPLE_MODELS = {
     # [Decoder-only]
     "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria"),
-    "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"),  # noqa: E501
+    "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b",  # noqa: E501
+                                                     extras={"6b": "Salesforce/blip2-opt-6.7b"}),  # noqa: E501
     "ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"),  # noqa: E501
     "DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny",  # noqa: E501
                                                extras={"fork": "Isotr0py/deepseek-vl2-tiny"},  # noqa: E501
@@ -875,7 +875,8 @@ class Florence2MultiModalProcessor(
                                         Florence2MultiModalProcessor,
                                         info=Florence2ProcessingInfo,
                                         dummy_inputs=Florence2DummyInputsBuilder)
-class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                        SupportsV0Only):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -39,7 +39,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate, PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
@@ -66,10 +65,13 @@ class FuyuImagePatchInputs(TypedDict):
     This is used to split the embeddings which has the first two dimensions
     flattened just like `flat_data`.
     """

     embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.
+
+    Shape: `(batch_size * num_images, num_embeds)`
     """

@@ -322,16 +324,18 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
-        embed_is_patch = kwargs.pop("embed_is_patch", None)
         if image_patches is not None:
             if not isinstance(image_patches, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image patches. "
                                  f"Got type: {type(image_patches)}")

+            embed_is_patch = kwargs.pop("embed_is_patch")
             if not isinstance(embed_is_patch, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of embed_is_patch. "
                                  f"Got type: {type(embed_is_patch)}")

             image_patches_flat = flatten_bn(image_patches)
+            embed_is_patch = flatten_bn(embed_is_patch)

             return FuyuImagePatchInputs(
                 type="image_patches",
@@ -351,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         assert self.vision_embed_tokens is not None
         vision_embeddings_flat, _ = self.vision_embed_tokens(
             image_patches_flat)

         return vision_embeddings_flat.split(patches_per_image, dim=0)

     def get_multimodal_embeddings(
@@ -358,13 +363,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
-        vision_embeddings = self._process_image_input(image_input)
-        #return vision_embeddings
-        return flatten_2d_lists(
-            scatter_patch_features(*args) for args in zip(
-                vision_embeddings,
+
+        image_features = self._process_image_input(image_input)
+
+        return scatter_patch_features(
+            image_features,
             image_input["embed_is_patch"],
-            ))
+        )

     def get_input_embeddings(
         self,
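
Note on the Fuyu change above: `embed_is_patch` is now popped, validated and flattened inside `_parse_and_validate_image_input`, so `get_multimodal_embeddings` can return `scatter_patch_features(image_features, embed_is_patch)` directly, one 2D tensor per image with patch embeddings placed at the masked positions. A rough standalone sketch of that scatter step, with hypothetical names and sizes (not vLLM's actual helper):

import torch

def scatter_patches_sketch(patches: torch.Tensor,
                           is_patch: torch.Tensor) -> torch.Tensor:
    # Allocate one row per embedding position; non-patch rows stay NaN.
    out = patches.new_full((is_patch.shape[0], patches.shape[-1]), torch.nan)
    out[is_patch] = patches  # patches is already 2D: (num_patches, hidden)
    return out

patches = torch.randn(4, 8)                                    # 4 patch embeddings
is_patch = torch.tensor([1, 1, 0, 1, 1, 0], dtype=torch.bool)  # 4 True positions
assert scatter_patches_sketch(patches, is_patch).shape == (6, 8)
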
@@ -613,7 +613,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def _process_image_input(
         self,
         image_input: Gemma3ImageInputs,
-    ) -> tuple[torch.Tensor, ...]:
+    ) -> list[torch.Tensor]:
         assert self.vision_tower is not None

         pixel_values = image_input["pixel_values"]
@@ -625,7 +625,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         )
         image_embeds = self.multi_modal_projector(image_features)

-        return image_embeds.split(num_patches.tolist())
+        return [
+            e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
+        ]

     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
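
On the Gemma3 change above: `image_embeds.split(num_patches.tolist())` yields one tensor per image whose leading dimensions are per-image patches and tokens per patch; `flatten(0, 1)` merges them so each item becomes the 2D (num_tokens, hidden_size) tensor the new sanity check expects. The same pattern is applied to Idefics3, LLaVA-NeXT-Video and MiniCPM-V below. A small shape illustration with made-up sizes:

import torch

hidden_size, tokens_per_patch = 32, 4
image_embeds = torch.randn(5, tokens_per_patch, hidden_size)  # 5 patches in total
num_patches = torch.tensor([2, 3])                            # two images

per_image = [
    e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
]
assert [tuple(e.shape) for e in per_image] == [(8, 32), (12, 32)]
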
@@ -733,7 +733,10 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
             pixel_attention_mask=pixel_attention_mask,
         )

-    def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
+    def _process_image_input(
+        self,
+        image_input: ImageInputs,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]

@@ -741,7 +744,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         image_features = self.model.connector(image_features)

         num_patches = image_input["num_patches"]
-        return image_features.split(num_patches.tolist())
+        return [
+            e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
+        ]

     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
@@ -406,20 +406,21 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
                                                  h, w)
             stacked_embeddings = self._video_pixels_to_features(
                 self.vision_tower, stacked_pixels)
-            return stacked_embeddings.view(b, num_frames,
-                                           *stacked_embeddings.shape[1:])
+            embeds = stacked_embeddings.view(b, num_frames,
+                                             *stacked_embeddings.shape[1:])

         elif is_list_of(video_pixels, torch.Tensor):
             frames_per_videos = [v.shape[0] for v in video_pixels]
             stacked_pixels = torch.cat(video_pixels, dim=0)
             stacked_embeddings = self._video_pixels_to_features(
                 self.vision_tower, stacked_pixels)
-            return torch.split(stacked_embeddings, frames_per_videos, dim=0)
+            embeds = torch.split(stacked_embeddings, frames_per_videos, dim=0)

         else:
             raise ValueError(
                 f"Unsupported type of video input {type(video_pixels)}")

+        return [e.flatten(0, 1) for e in embeds]
+
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         video_input = self._parse_and_validate_video_input(**kwargs)
@@ -919,8 +919,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):

         image_features_flat = self.get_vision_hidden_states(image_input)

-        # Reconstruct the batch dimension
-        return image_features_flat.split(image_input["num_slices"].tolist())
+        num_slices = image_input["num_slices"]
+        return [
+            e.flatten(0, 1)
+            for e in image_features_flat.split(num_slices.tolist())
+        ]

     def _process_multimodal_inputs(self, modalities: dict):
         # The result multimodal_embeddings is tuple of tensors, with each
@@ -204,7 +204,7 @@ def scatter_patch_features(
             (e_is_patch.shape[0], patches_one.shape[-1]),
             fill_value=torch.nan,
         )
-        embed_one[e_is_patch] = patches_one.flatten(0, -2)
+        embed_one[e_is_patch] = patches_one
         return embed_one

     return tuple(
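
The dropped `.flatten(0, -2)` in `scatter_patch_features` above is no longer needed because, after the model-side changes, each `patches_one` already arrives as a 2D (num_tokens, hidden_size) tensor; on a 2D tensor that flatten is a no-op. A quick shape check with hypothetical sizes:

import torch

x3d = torch.randn(2, 4, 8)   # old-style per-item embedding: 3D
assert x3d.flatten(0, -2).shape == (8, 8)

x2d = torch.randn(8, 8)      # new-style per-item embedding: already 2D
assert x2d.flatten(0, -2).shape == (8, 8)  # no-op
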
@@ -41,6 +41,8 @@ from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

+from .utils import sanity_check_mm_encoder_outputs
+
 if TYPE_CHECKING:
     import xgrammar as xgr

@@ -867,6 +869,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **batched_mm_inputs)

+            sanity_check_mm_encoder_outputs(
+                curr_group_outputs,
+                expected_num_items=len(grouped_mm_inputs),
+            )
+
             for output in curr_group_outputs:
                 encoder_outputs.append(output)

@@ -1490,12 +1497,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Run multimodal encoder.
             dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                 **batched_dummy_mm_inputs)
-            assert len(dummy_encoder_outputs) == max_num_mm_items, (
-                "Expected dimension 0 of encoder outputs to match the number "
-                f"of multimodal data items: {max_num_mm_items}, got "
-                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
-                "due to the 'get_multimodal_embeddings' method of the model "
-                "not implemented correctly.")
+
+            sanity_check_mm_encoder_outputs(
+                dummy_encoder_outputs,
+                expected_num_items=max_num_mm_items,
+            )

             # Cache the dummy encoder outputs.
             self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
@@ -37,6 +37,8 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

+from .utils import sanity_check_mm_encoder_outputs
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput

@@ -512,6 +514,11 @@ class TPUModelRunner:
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **batched_mm_inputs)

+            sanity_check_mm_encoder_outputs(
+                curr_group_outputs,
+                expected_num_items=len(grouped_mm_inputs),
+            )
+
             for output in curr_group_outputs:
                 encoder_outputs.append(output)

vllm/v1/worker/utils.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sanity_check_mm_encoder_outputs(
+    mm_embeddings: object,
+    expected_num_items: int,
+) -> None:
+    """
+    Perform sanity checks for the result of
+    :meth:`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`.
+    """
+    assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
+        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
+        f"or a single 3D tensor, but got {type(mm_embeddings)} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
+
+    assert len(mm_embeddings) == expected_num_items, (
+        "Expected number of multimodal embeddings to match number of "
+        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")
+
+    assert all(e.ndim == 2 for e in mm_embeddings), (
+        "Expected multimodal embeddings to be a sequence of 2D tensors, "
+        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
+        "instead. This is most likely due to incorrect implementation "
+        "of the model's `get_multimodal_embeddings` method.")