[VLM][Bugfix] Multi-modal processor compatible with V1 multi-input (#11674)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-01-02 17:00:00 +08:00 committed by GitHub
parent a115ac46b5
commit 23c1b10a4c
3 changed files with 151 additions and 168 deletions

View File

@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
-from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final
+from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
+                    final)

 import numpy as np
 import torch
@@ -11,7 +12,7 @@ from PIL.Image import Image
 from transformers import BatchFeature
 from typing_extensions import NotRequired, TypeAlias

-from vllm.utils import JSONTree, is_list_of, json_map_leaves
+from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves

 _T = TypeVar("_T")
@@ -160,11 +161,8 @@ A dictionary containing nested tensors which have been batched via

 @dataclass(frozen=True)
-class MultiModalFieldItem:
-    """
-    Contains metadata and data in :class:`MultiModalKwargs`
-    corresponding to a data item in :class:`MultiModalDataItems`.
-    """
+class MultiModalFieldElem:
+    """Contains metadata and data of an item in :class:`MultiModalKwargs`."""

     field: "BaseMultiModalField"
     data: NestedTensors
@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         raise NotImplementedError

-    def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
-        return MultiModalFieldItem(self, data)
+    def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
+        return MultiModalFieldElem(self, data)

-    def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
-        """Merge multiple instances of :class:`MultiModalFieldItem` together."""
+    def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
+        """Merge multiple instances of :class:`MultiModalFieldElem` together."""
         fields = [item.field for item in batch]
         if len(set(fields)) > 1:
             raise ValueError(f"Cannot merge different {fields=}")

         data = self._reduce_data([item.data for item in batch])

-        return self._build_item(data)
+        return self._build_elem(data)


 @dataclass(frozen=True)
 class MultiModalBatchedField(BaseMultiModalField):
     """
-    A :class:`BaseMultiModalField` implementation where an item is obtained by
-    directly indexing into the first dimension of the underlying data.
+    A :class:`BaseMultiModalField` implementation where an element in the batch
+    is obtained by indexing into the first dimension of the underlying data.
     """

-    def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
-        return [self._build_item(item) for item in batch]
+    def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
+        return [self._build_elem(item) for item in batch]

     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
             first_shape = batch[0].shape
-            if all(item.shape == first_shape for item in batch):
+            if all(elem.shape == first_shape for elem in batch):
                 return torch.stack(batch)

         return batch
@@ -222,24 +220,24 @@ class MultiModalBatchedField(BaseMultiModalField):
 @dataclass(frozen=True)
 class MultiModalFlatField(BaseMultiModalField):
     """
-    A :class:`BaseMultiModalField` implementation where an item is obtained by
-    slicing along the first dimension of the underlying data.
+    A :class:`BaseMultiModalField` implementation where an element in the batch
+    is obtained by slicing along the first dimension of the underlying data.
     """

-    def build_items(
+    def build_elems(
         self,
         batch: NestedTensors,
         slices: Sequence[slice],
-    ) -> list[MultiModalFieldItem]:
-        return [self._build_item(batch[slice_]) for slice_ in slices]
+    ) -> list[MultiModalFieldElem]:
+        return [self._build_elem(batch[slice_]) for slice_ in slices]

     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
             first_shape = batch[0].shape
-            if all(item.shape[1:] == first_shape[1:] for item in batch):
+            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                 return torch.concat(batch)

-        return [elem for item in batch for elem in item]
+        return [e for elem in batch for e in elem]


 class MultiModalFieldConfig:
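
Note: the two fields above implement different reduction strategies. A minimal, self-contained sketch (plain torch tensors, not the vLLM classes) of what each `_reduce_data` fast path does when the shapes allow it:

import torch

# Batched-style reduce: equal-shaped per-item tensors are stacked,
# creating a new leading batch dimension.
items = [torch.zeros(3, 4), torch.zeros(3, 4)]
assert torch.stack(items).shape == (2, 3, 4)

# Flat-style reduce: tensors agreeing on all trailing dimensions are
# concatenated along the existing first dimension, so items may have
# different lengths (e.g. varying numbers of image patches).
slices = [torch.zeros(5, 4), torch.zeros(2, 4)]
assert torch.concat(slices).shape == (7, 4)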
@@ -267,115 +265,111 @@ class MultiModalFieldConfig:
     ) -> None:
         super().__init__()

-        self._field_cls = field_cls
-        self._modality = modality
-        self._field_config = field_config
+        self.field_cls = field_cls
+        self.modality = modality
+        self.field_config = field_config

-    def build_items(
+    def build_elems(
         self,
         key: str,
         batch: NestedTensors,
-    ) -> list[MultiModalFieldItem]:
-        field = self._field_cls(key=key, modality=self._modality)
-        return field.build_items(batch, **self._field_config)  # type: ignore
+    ) -> Sequence[MultiModalFieldElem]:
+        field = self.field_cls(key=key, modality=self.modality)
+        return field.build_elems(batch, **self.field_config)  # type: ignore


+class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
+    """
+    A collection of :class:`MultiModalFieldElem`
+    corresponding to a data item in :class:`MultiModalDataItems`.
+    """
+
+    @staticmethod
+    def from_elems(elems: Sequence[MultiModalFieldElem]):
+        return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
+
+    @property
+    def modality(self) -> str:
+        modalities = {elem.field.modality for elem in self.data.values()}
+        assert len(modalities) == 1, f"Found different modalities={modalities}"
+        return next(iter(modalities))
+
+
+# NOTE: UserDict is for V0 compatibility.
+# V1 should access individual items via `get_item`.
 class MultiModalKwargs(UserDict[str, NestedTensors]):
     """
     A dictionary that represents the keyword arguments to
     :meth:`~torch.nn.Module.forward`.

-    The metadata :code:`items_by_key` defines how to split batched keyword
-    arguments corresponding to each data item in :class:`MultiModalDataItems`:
-
-    - For a keyword argument, we can access the :code:`i` th item in the batch
-      via :code:`items_by_key[key][i]`.
-    - We can gather the keyword arguments belonging to a modality by finding
-      the keys with items that belong to that modality, then accessing
-      the :code:`i` th item in the batch for each such key.
-
-    Example:
-
-        .. code-block:: python
-
-            items_by_key={
-                "pixel_values": [a, b, c, d],  # "image" modality
-                "image_grid_thw": [e, f, g, h],  # "image" modality
-                "pixel_values_video": [h, i, j],  # "video" modality
-                "video_grid_thw": [k, l, m],  # "video" modality
-            }
-
-        - The keyword arguments belonging to the first image are
-          :code:`{"pixel_values": a, "image_grid_thw": e}`.
-
-        - The keyword arguments belonging to the second video are
-          :code:`{"pixel_values_video": i, "video_grid_thw": l}`.
+    The metadata :code:`items` enables us to obtain the keyword arguments
+    corresponding to each data item in :class:`MultiModalDataItems`, via
+    :meth:`get_item` and :meth:`get_items`.
     """

     @staticmethod
     def from_hf_inputs(
         hf_inputs: BatchFeature,
         config_by_key: Mapping[str, MultiModalFieldConfig],
-        *,
-        enable_sanity_checks: bool = False,
     ):
         # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
         # We assume that those fields are not used in vLLM
-        items_by_key = {
-            key: config.build_items(key, batch)
-            for key, config in config_by_key.items()
-            if (batch := hf_inputs.get(key)) is not None
-        }
+        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
+        keys_by_modality = defaultdict[str, set[str]](set)
+        for key, config in config_by_key.items():
+            batch = hf_inputs.get(key)
+            if batch is not None:
+                elems = config.build_elems(key, batch)
+                if len(elems) > 0:
+                    elems_by_key[key] = elems
+                    keys_by_modality[config.modality].add(key)

-        return MultiModalKwargs.from_items_by_key(
-            items_by_key,
-            enable_sanity_checks=enable_sanity_checks,
-        )
+        items = list[MultiModalKwargsItem]()
+        for modality, keys in keys_by_modality.items():
+            elems_in_modality = {k: elems_by_key[k] for k in keys}
+            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
+
+            if len(set(batch_sizes.values())) > 1:
+                raise ValueError(
+                    f"Cannot merge different batch sizes for {modality=}! "
+                    f"Found: {batch_sizes=}")
+
+            batch_size = next(iter(batch_sizes.values()))
+            for item_idx in range(batch_size):
+                elems = [v[item_idx] for v in elems_in_modality.values()]
+                items.append(MultiModalKwargsItem.from_elems(elems))
+
+        return MultiModalKwargs.from_items(items)

     @staticmethod
-    def from_items_by_key(
-        items_by_key: Mapping[str, list[MultiModalFieldItem]],
-        *,
-        enable_sanity_checks: bool = False,
-    ) -> "MultiModalKwargs":
+    def from_items(items: Sequence[MultiModalKwargsItem]):
+        """Construct a new :class:`MultiModalKwargs` from multiple items."""
+        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
+        for item in items:
+            for key, elem in item.items():
+                elems_by_key[key].append(elem)
+
         data = {
-            key: items[0].field.reduce(items).data
-            for key, items in items_by_key.items() if len(items) > 0
+            key: elems[0].field.reduce(elems).data
+            for key, elems in elems_by_key.items() if len(elems) > 0
         }

-        return MultiModalKwargs(data,
-                                items_by_key=items_by_key,
-                                enable_sanity_checks=enable_sanity_checks)
+        return MultiModalKwargs(data, items=items)

     def __init__(
         self,
         data: Mapping[str, NestedTensors],
         *,
-        items_by_key: Mapping[str, list[MultiModalFieldItem]] = {},
-        enable_sanity_checks: bool = False,
+        items: Optional[Sequence[MultiModalKwargsItem]] = None,
     ) -> None:
         super().__init__(data)

-        # Shallow copy to avoid footgun in case a defaultdict is passed in
-        self._items_by_key = dict(items_by_key)
-
-        keys_by_modality = defaultdict[str, set[str]](set)
-        for key, items in items_by_key.items():
-            for item in items:
-                keys_by_modality[item.field.modality].add(key)
-
-        self._keys_by_modality = dict(keys_by_modality)
-
-        if enable_sanity_checks:
-            for modality, keys in keys_by_modality.items():
-                items_in_modality = {k: items_by_key[k] for k in keys}
-                batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
-                batch_size = next(iter(batch_sizes.values()), 0)
-                assert all(bs == batch_size
-                           for bs in batch_sizes.values()), dict(
-                               modality=modality,
-                               batch_sizes=batch_sizes,
-                               items_by_key=items_by_key)
+        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
+        self._items_by_modality = dict(items_by_modality)
+
+    @property
+    def modalities(self):
+        return self._items_by_modality.keys()

     @staticmethod
     def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
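
Note: the rewritten `from_hf_inputs` replaces the old key-wise construction with a modality-wise one. A toy sketch of the invariant it enforces (plain dicts and strings standing in for keys and `MultiModalFieldElem` values):

# Every key belonging to a modality must yield the same number of
# elements...
elems_by_key = {
    "pixel_values": ["pv0", "pv1"],
    "image_grid_thw": ["thw0", "thw1"],
}
batch_sizes = {k: len(v) for k, v in elems_by_key.items()}
assert len(set(batch_sizes.values())) == 1, f"Mismatched {batch_sizes=}"

# ...so that element i of every key can be zipped into the kwargs of
# item i, mirroring what MultiModalKwargsItem.from_elems produces.
batch_size = next(iter(batch_sizes.values()))
items = [{k: v[i] for k, v in elems_by_key.items()} for i in range(batch_size)]
assert items[1] == {"pixel_values": "pv1", "image_grid_thw": "thw1"}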
@@ -452,58 +446,44 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        if self._items_by_key != other._items_by_key:
+        if self._items_by_modality != other._items_by_modality:
             return False

         ks = self.keys()
         return (ks == other.keys()
                 and all(nested_tensors_equal(self[k], other[k]) for k in ks))

-    def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
-        return self._items_by_key[key][item_index]
+    def _validate_modality(self, method_name: str, modality: str) -> None:
+        if not self._items_by_modality:
+            raise RuntimeError(
+                f"`{method_name}` is not supported when "
+                "MultiModalKwargs is not initialized with `items`")

-    def get_items_by_modality(
-        self,
-        modality: str,
-        item_index: int,
-    ) -> Mapping[str, MultiModalFieldItem]:
-        if modality not in self._keys_by_modality:
-            available_modalities = set(self._keys_by_modality.keys())
+        if modality not in self._items_by_modality:
+            available_modalities = set(self._items_by_modality.keys())
             raise KeyError(f"Modality {modality!r} not found. "
                            f"Available modalities: {available_modalities}")

-        keys_to_gather = self._keys_by_modality[modality]
-
-        return {
-            key: self.get_item(key, item_index)
-            for key in keys_to_gather if key in self
-        }
-
-    @staticmethod
-    def from_items_by_modality(
-        items_by_modality: Mapping[str, list[Mapping[str,
-                                                     MultiModalFieldItem]]],
-        *,
-        enable_sanity_checks: bool = False,
-    ) -> "MultiModalKwargs":
-        """
-        Construct a new :class:`MultiModalKwargs` from multiple items returned
-        by :meth:`get_fields_by_modality`.
-        """
-        items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
-        for fields in items_by_modality.values():
-            for field in fields:
-                for k, v in field.items():
-                    items_by_key[k].append(v)
-
-        return MultiModalKwargs.from_items_by_key(
-            items_by_key,
-            enable_sanity_checks=enable_sanity_checks,
-        )
+    def get_item_count(self, modality: str) -> int:
+        """Get the number of items belonging to a modality."""
+        self._validate_modality("get_item_count", modality)
+        return len(self._items_by_modality[modality])
+
+    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
+        """
+        Get the keyword arguments corresponding to an item identified by
+        its modality and index.
+        """
+        self._validate_modality("get_item", modality)
+        return self._items_by_modality[modality][item_index]
+
+    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
+        """
+        Get the keyword arguments corresponding to each item belonging to
+        a modality.
+        """
+        self._validate_modality("get_items", modality)
+        return self._items_by_modality[modality]


 MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
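
Note: a toy stand-in for the new accessor API introduced above (not the vLLM class itself), showing how `get_item_count`, `get_item`, and `get_items` behave once items are grouped by modality:

from collections import defaultdict

class ToyKwargs:
    def __init__(self, items):
        # items: (modality, item) pairs, grouped on construction much
        # like full_groupby does in MultiModalKwargs.__init__.
        self._items_by_modality = defaultdict(list)
        for modality, item in items:
            self._items_by_modality[modality].append(item)

    def get_item_count(self, modality):
        return len(self._items_by_modality[modality])

    def get_item(self, modality, item_index):
        return self._items_by_modality[modality][item_index]

    def get_items(self, modality):
        return self._items_by_modality[modality]

kw = ToyKwargs([("image", "img0"), ("image", "img1"), ("video", "vid0")])
assert kw.get_item_count("image") == 2
assert kw.get_item("video", 0) == "vid0"
assert kw.get_items("image") == ["img0", "img1"]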

View File

@@ -20,8 +20,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import LRUCache, flatten_2d_lists, full_groupby

 from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                     MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs,
-                     PlaceholderRange)
+                     MultiModalInputsV2, MultiModalKwargs,
+                     MultiModalKwargsItem, PlaceholderRange)
 from .parse import MultiModalDataItems, MultiModalDataParser

 logger = init_logger(__name__)
@@ -496,8 +496,7 @@ class ProcessingCache:
         # DEBUG: Set to None to disable
         self.debug_cache_hit_ratio_steps: Optional[int] = None

-        self._cache = LRUCache[str, Mapping[str,
-                                            MultiModalFieldItem]](capacity)
+        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

     def _maybe_log_cache_stats(self) -> None:
         steps = self.debug_cache_hit_ratio_steps
@@ -565,7 +564,7 @@ class ProcessingCache:
         modality: str,
         input_item: object,
         input_kwargs: Mapping[str, object],
-    ) -> Optional[Mapping[str, MultiModalFieldItem]]:
+    ) -> Optional[MultiModalKwargsItem]:
         """
         Get a processed multi-modal item from the cache
         according to its dependencies, including:
@@ -588,7 +587,7 @@ class ProcessingCache:
         modality: str,
         input_item: object,
         input_kwargs: Mapping[str, object],
-        output_kwargs: Mapping[str, MultiModalFieldItem],
+        output_kwargs: MultiModalKwargsItem,
     ) -> None:
         """
         Put a processed multi-modal item into the cache
@@ -784,7 +783,6 @@ class BaseMultiModalProcessor(ABC):
         mm_kwargs = MultiModalKwargs.from_hf_inputs(
             processed_data,
             self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
-            enable_sanity_checks=self.enable_sanity_checks,
         )

         return prompt_ids, mm_kwargs
@@ -846,7 +844,7 @@ class BaseMultiModalProcessor(ABC):
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
         )

-        mm_maybe_cached_field_items = {
+        mm_maybe_cached_kw_items = {
             modality: [
                 cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                 for item in items
@@ -855,8 +853,9 @@ class BaseMultiModalProcessor(ABC):
         }

         mm_missing_idxs = {
-            modality: [idx for idx, out in enumerate(fields) if out is None]
-            for modality, fields in mm_maybe_cached_field_items.items()
+            modality:
+            [idx for idx, item in enumerate(kw_items) if item is None]
+            for modality, kw_items in mm_maybe_cached_kw_items.items()
         }
         mm_missing_data = {
             modality: [mm_data_items[modality][idx] for idx in idxs]
@@ -875,14 +874,11 @@ class BaseMultiModalProcessor(ABC):
             for modality in mm_missing_data_items
         }

-        mm_merged_field_items = dict[str, list[Mapping[str,
-                                                       MultiModalFieldItem]]]()
-        for modality, modal_items_lst in mm_maybe_cached_field_items.items():
-            merged_modal_items_lst = list[Mapping[str, MultiModalFieldItem]]()
-
-            for idx, modal_items in enumerate(modal_items_lst):
-                if modal_items is None:
-                    modal_items = mm_missing_kwargs.get_items_by_modality(
+        merged_kw_items = list[MultiModalKwargsItem]()
+        for modality, kw_items in mm_maybe_cached_kw_items.items():
+            for idx, kw_item in enumerate(kw_items):
+                if kw_item is None:
+                    kw_item = mm_missing_kwargs.get_item(
                         modality,
                         mm_missing_next_idx[modality],
                     )
@@ -892,14 +888,12 @@ class BaseMultiModalProcessor(ABC):
                         modality,
                         mm_data_items[modality][idx],
                         hf_processor_mm_kwargs,
-                        modal_items,
+                        kw_item,
                     )

                     mm_missing_next_idx[modality] += 1

-                merged_modal_items_lst.append(modal_items)
-
-            mm_merged_field_items[modality] = merged_modal_items_lst
+                merged_kw_items.append(kw_item)

         if self.enable_sanity_checks:
             mm_missing_counts = mm_missing_data_items.get_all_counts()
@@ -909,10 +903,7 @@ class BaseMultiModalProcessor(ABC):
                     mm_missing_next_idx=mm_missing_next_idx,
                     mm_missing_counts=mm_missing_counts)

-        mm_kwargs = MultiModalKwargs.from_items_by_modality(
-            mm_merged_field_items,
-            enable_sanity_checks=self.enable_sanity_checks,
-        )
+        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

         if self.enable_sanity_checks:
             mm_item_counts = mm_data_items.get_all_counts()
@@ -920,7 +911,7 @@
             for modality, item_count in mm_item_counts.items():
                 for item_idx in range(item_count):
                     try:
-                        mm_kwargs.get_items_by_modality(modality, item_idx)
+                        mm_kwargs.get_item(modality, item_idx)
                     except Exception as e:
                         # Make it easy to set a breakpoint in the debugger
                         raise e
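
Note: the cache-merge loop above interleaves cache hits with freshly computed items. A self-contained toy version of that control flow (strings in place of MultiModalKwargsItem objects):

# `cached` holds per-item cache lookups (None = miss); `missing` holds
# freshly processed items in the same order as the misses.
cached = {"image": ["item0", None, "item2", None]}
missing = {"image": ["new1", "new3"]}

next_idx = {modality: 0 for modality in cached}
merged = []
for modality, maybe_items in cached.items():
    for maybe_item in maybe_items:
        if maybe_item is None:
            maybe_item = missing[modality][next_idx[modality]]
            next_idx[modality] += 1
        merged.append(maybe_item)

assert merged == ["item0", "new1", "item2", "new3"]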

View File

@@ -113,15 +113,27 @@ class Processor:
         # For merged preprocessor, mm_data is already mm_inputs
         precomputed_mm_inputs = None
-        if isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
-            precomputed_mm_inputs = [decoder_inputs.multi_modal_data]
+        decoder_mm_data = decoder_inputs.multi_modal_data
+        if isinstance(decoder_mm_data, MultiModalKwargs):
+            # The output of merged multi-modal processor (`decoder_mm_data`)
+            # contains the kwargs for all items from all modalities.
+            # This code separates them so that there is one set of kwargs
+            # per item per modality.
+            precomputed_mm_inputs = [
+                MultiModalKwargs.from_items([item])
+                for modality in decoder_mm_data.modalities
+                for item in decoder_mm_data.get_items(modality)
+            ]

         # Apply MM mapper
         mm_inputs = None
-        if len(decoder_inputs.multi_modal_data) > 0:
+        if len(decoder_mm_data) > 0:
             mm_inputs = self.mm_input_mapper_client.process_inputs(
-                decoder_inputs.multi_modal_data, mm_hashes,
-                decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
+                decoder_mm_data,
+                mm_hashes,
+                decoder_inputs.mm_processor_kwargs,
+                precomputed_mm_inputs,
+            )

         return EngineCoreRequest(
             request_id,
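
Note: the new splitting step turns one combined MultiModalKwargs into one precomputed input per item per modality. A toy version with strings in place of MultiModalKwargsItem objects:

items_by_modality = {"image": ["img0", "img1"], "video": ["vid0"]}

precomputed_mm_inputs = [
    item
    for modality in items_by_modality        # cf. .modalities
    for item in items_by_modality[modality]  # cf. .get_items(modality)
]
assert precomputed_mm_inputs == ["img0", "img1", "vid0"]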