# SPDX-License-Identifier: Apache-2.0
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
FuyuProcessor)
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                   MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                    merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features

# The following two token IDs cannot be found in the HF config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019


class FuyuImagePatchInputs(TypedDict):
    type: Literal["image_patches"]

    flat_data: torch.Tensor
    """
    Shape:
    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
    """

    patches_per_image: list[int]
    """
    The total number of patches for each image in the batch.

    This is used to split the embeddings, whose first two dimensions are
    flattened just like `flat_data`.
    """

    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
    """
    A boolean mask indicating which image embeddings correspond
    to patch tokens.

    Shape: `(batch_size * num_images, num_embeds)`
    """


class FuyuProcessingInfo(BaseProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config(FuyuConfig)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(FuyuProcessor, **kwargs)

    def get_image_processor(self) -> FuyuImageProcessor:
        return self.get_hf_processor().image_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"image": 1}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        target_width, target_height = self.get_image_size_with_most_features()

        max_ncols, max_nrows = self.get_image_feature_grid_size(
            image_width=target_width,
            image_height=target_height,
        )
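
        # Each row of ncols patch tokens is followed by one newline token,
        # hence the +1 column per row below.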
        max_image_tokens = (max_ncols + 1) * max_nrows

        return {"image": max_image_tokens}

    def get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> tuple[int, int]:
        image_processor = self.get_image_processor()
        target_width = image_processor.size["width"]
        target_height = image_processor.size["height"]
        patch_width = image_processor.patch_size["width"]
        patch_height = image_processor.patch_size["height"]

        if not (image_width <= target_width
                and image_height <= target_height):
            height_scale_factor = target_height / image_height
            width_scale_factor = target_width / image_width
            optimal_scale_factor = min(height_scale_factor,
                                       width_scale_factor)

            image_height = int(image_height * optimal_scale_factor)
            image_width = int(image_width * optimal_scale_factor)

        ncols = math.ceil(image_width / patch_width)
        nrows = math.ceil(image_height / patch_height)
        return ncols, nrows
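
    # Worked example (a sketch, assuming the HF FuyuImageProcessor defaults
    # of size={"height": 1080, "width": 1920} and
    # patch_size={"height": 30, "width": 30}): an 800x600 image already fits
    # the target canvas, so it is not rescaled, giving
    #   ncols = ceil(800 / 30) = 27, nrows = ceil(600 / 30) = 20
    # i.e. 27 * 20 = 540 patch tokens plus 20 newline tokens, matching the
    # (ncols + 1) * nrows = 560 placeholder tokens counted above.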

    def get_image_size_with_most_features(self) -> ImageSize:
        image_processor = self.get_image_processor()
        return ImageSize(width=image_processor.size["width"],
                         height=image_processor.size["height"])


class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        target_width, target_height = \
            self.info.get_image_size_with_most_features()
        num_images = mm_counts.get("image", 0)

        mm_data = {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text="",
            mm_data=mm_data,
        )


class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if not mm_data:
            # Avoid warning from HF logger for text-only input
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]),
                                tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        image_patches = processed_outputs.get("image_patches")
        if image_patches is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, Pn, Px * Py * C)
            # New output: (num_images, Pn, Px * Py * C)
            assert (isinstance(image_patches, list)
                    and len(image_patches) == 1)
            assert (isinstance(image_patches[0], torch.Tensor)
                    and len(image_patches[0]) == len(images))

            processed_outputs["image_patches"] = image_patches[0]

            # Compute the patch-token mask for each image from its grid size
            embed_is_patch = []
            for image in images:
                ncols, nrows = self.info.get_image_feature_grid_size(
                    image_width=image.width,
                    image_height=image.height,
                )
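                # Each row contributes ncols patch embeddings followed by one
                # newline embedding; only the patch positions are marked True.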
                mask = torch.tensor(([True] * ncols + [False]) * nrows)
                embed_is_patch.append(mask)

            processed_outputs["embed_is_patch"] = embed_is_patch

        return processed_outputs

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        # HF processor adds boa_token_id
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()
        boa_token_id = vocab["<0x04>"]

        return prompt_tokens + [boa_token_id]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(image_patches=MultiModalFieldConfig.batched("image"),
                    embed_is_patch=MultiModalFieldConfig.batched("image"))

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        hf_config = self.info.get_hf_config()
        bos_token_id = hf_config.bos_token_id
        assert isinstance(bos_token_id, int)

        tokenizer = self.info.get_tokenizer()
        eot_token_id = tokenizer.bos_token_id
        assert isinstance(eot_token_id, int)

        def get_replacement_fuyu(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = self.info.get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
            )
            image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
                            [_NEWLINE_TOKEN_ID]) * nrows
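
            # Illustration (an assumed 2x3 grid, i.e. ncols=2, nrows=3):
            # the replacement sequence is
            #   [IMG, IMG, NL, IMG, IMG, NL, IMG, IMG, NL] + [BOS]
            # where only the IMG positions carry image features.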
            return PromptUpdateDetails(
                full=image_tokens + [bos_token_id],
                features=image_tokens,
            )

        return [
            PromptReplacement(
                modality="image",
                target=[eot_token_id],
                replacement=get_replacement_fuyu,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor,
                                        info=FuyuProcessingInfo,
                                        dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        self.image_feature_size = config.patch_size**2 * config.num_channels

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
            gather_output=True,
        )
        self.language_model = PersimmonForCausalLM(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @property
    def sampler(self):
        return self.language_model.sampler

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.patch_size
        num_channels = self.config.num_channels
        expected_dims = num_channels * h * w
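        # For example (an assumption based on the default FuyuConfig, where
        # patch_size=30 and num_channels=3): expected_dims = 3 * 30 * 30
        # = 2700 pixel values per patch.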

        def _validate_shape(d: torch.Tensor):
            actual_dims = d.size(-1)

            if actual_dims != expected_dims:
                raise ValueError(
                    "The expected number of pixel values per patch is "
                    f"{expected_dims}. You supplied a tensor of shape "
                    f"{tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data.to(self.vision_embed_tokens.weight.dtype)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
        image_patches = kwargs.pop("image_patches", None)
        if image_patches is not None:
            if not isinstance(image_patches, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image patches. "
                                 f"Got type: {type(image_patches)}")

            embed_is_patch = kwargs.pop("embed_is_patch")
            if not isinstance(embed_is_patch, (torch.Tensor, list)):
                raise ValueError("Incorrect type of embed_is_patch. "
                                 f"Got type: {type(embed_is_patch)}")

            image_patches_flat = flatten_bn(image_patches)
            embed_is_patch = flatten_bn(embed_is_patch)

            return FuyuImagePatchInputs(
                type="image_patches",
                flat_data=self._validate_pixel_values(
                    flatten_bn(image_patches_flat, concat=True)),
                patches_per_image=[x.size(0) for x in image_patches_flat],
                embed_is_patch=embed_is_patch,
            )

        return None

    def _process_image_input(
            self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]

        assert self.vision_embed_tokens is not None
        vision_embeddings_flat, _ = self.vision_embed_tokens(
            image_patches_flat)
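
        # Split the flattened embeddings back into one tensor per image, each
        # of shape (patches_per_image[i], hidden_size).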
        return vision_embeddings_flat.split(patches_per_image, dim=0)

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        image_features = self._process_image_input(image_input)
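
        # Scatter the patch embeddings to the positions marked True in
        # `embed_is_patch`; `select_patch_features` recovers them when the
        # embeddings are merged in `get_input_embeddings`.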
        return scatter_patch_features(
            image_features,
            image_input["embed_is_patch"],
        )

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds,
                select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated by the model runner;
        # this branch exists only for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)
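

# Minimal usage sketch (illustrative only, not executed as part of this
# module; assumes the adept/fuyu-8b checkpoint and a local image file):
#
#     from PIL import Image
#     from vllm import LLM
#
#     llm = LLM(model="adept/fuyu-8b")
#     image = Image.open("example.png")
#     outputs = llm.generate({
#         "prompt": "Describe this image:\n",
#         "multi_modal_data": {"image": image},
#     })
#     print(outputs[0].outputs[0].text)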