[VLM][Bugfix] Multi-modal processor compatible with V1 multi-input (#11674)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent a115ac46b5
commit 23c1b10a4c
@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
-from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final
+from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
+                    final)

 import numpy as np
 import torch
@@ -11,7 +12,7 @@ from PIL.Image import Image
 from transformers import BatchFeature
 from typing_extensions import NotRequired, TypeAlias

-from vllm.utils import JSONTree, is_list_of, json_map_leaves
+from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves

 _T = TypeVar("_T")

@@ -160,11 +161,8 @@ A dictionary containing nested tensors which have been batched via


 @dataclass(frozen=True)
-class MultiModalFieldItem:
-    """
-    Contains metadata and data in :class:`MultiModalKwargs`
-    corresponding to a data item in :class:`MultiModalDataItems`.
-    """
+class MultiModalFieldElem:
+    """Contains metadata and data of an item in :class:`MultiModalKwargs`."""

     field: "BaseMultiModalField"
     data: NestedTensors

@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         raise NotImplementedError

-    def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
-        return MultiModalFieldItem(self, data)
+    def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
+        return MultiModalFieldElem(self, data)

-    def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
-        """Merge multiple instances of :class:`MultiModalFieldItem` together."""
+    def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
+        """Merge multiple instances of :class:`MultiModalFieldElem` together."""
         fields = [item.field for item in batch]
         if len(set(fields)) > 1:
             raise ValueError(f"Cannot merge different {fields=}")

         data = self._reduce_data([item.data for item in batch])

-        return self._build_item(data)
+        return self._build_elem(data)


 @dataclass(frozen=True)
 class MultiModalBatchedField(BaseMultiModalField):
     """
-    A :class:`BaseMultiModalField` implementation where an item is obtained by
-    directly indexing into the first dimension of the underlying data.
+    A :class:`BaseMultiModalField` implementation where an element in the batch
+    is obtained by indexing into the first dimension of the underlying data.
     """

-    def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
-        return [self._build_item(item) for item in batch]
+    def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
+        return [self._build_elem(item) for item in batch]

     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
             first_shape = batch[0].shape
-            if all(item.shape == first_shape for item in batch):
+            if all(elem.shape == first_shape for elem in batch):
                 return torch.stack(batch)

         return batch
@@ -222,24 +220,24 @@ class MultiModalBatchedField(BaseMultiModalField):
 @dataclass(frozen=True)
 class MultiModalFlatField(BaseMultiModalField):
     """
-    A :class:`BaseMultiModalField` implementation where an item is obtained by
-    slicing along the first dimension of the underlying data.
+    A :class:`BaseMultiModalField` implementation where an element in the batch
+    is obtained by slicing along the first dimension of the underlying data.
     """

-    def build_items(
+    def build_elems(
         self,
         batch: NestedTensors,
         slices: Sequence[slice],
-    ) -> list[MultiModalFieldItem]:
-        return [self._build_item(batch[slice_]) for slice_ in slices]
+    ) -> list[MultiModalFieldElem]:
+        return [self._build_elem(batch[slice_]) for slice_ in slices]

     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
             first_shape = batch[0].shape
-            if all(item.shape[1:] == first_shape[1:] for item in batch):
+            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                 return torch.concat(batch)

-        return [elem for item in batch for elem in item]
+        return [e for elem in batch for e in elem]


 class MultiModalFieldConfig:
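For context, the two field types differ only in how `_reduce_data` recombines per-item data: a batched field stacks same-shape tensors along a new leading dimension, while a flat field concatenates slices along the existing first dimension. A minimal sketch of that distinction in plain torch (illustrative only, not vLLM API):

    import torch

    # "batched" reduction: per-item tensors share a shape, so stacking adds
    # a new batch dimension at the front.
    per_item = [torch.zeros(3, 4), torch.zeros(3, 4)]
    batched = torch.stack(per_item)          # shape (2, 3, 4)

    # "flat" reduction: per-item slices share only the trailing dimensions,
    # so they are concatenated along the existing first dimension.
    flat_items = [torch.zeros(2, 4), torch.zeros(5, 4)]
    flattened = torch.concat(flat_items)     # shape (7, 4)

    print(batched.shape, flattened.shape)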
@@ -267,115 +265,111 @@ class MultiModalFieldConfig:
     ) -> None:
         super().__init__()

-        self._field_cls = field_cls
-        self._modality = modality
-        self._field_config = field_config
+        self.field_cls = field_cls
+        self.modality = modality
+        self.field_config = field_config

-    def build_items(
+    def build_elems(
         self,
         key: str,
         batch: NestedTensors,
-    ) -> list[MultiModalFieldItem]:
-        field = self._field_cls(key=key, modality=self._modality)
-        return field.build_items(batch, **self._field_config)  # type: ignore
+    ) -> Sequence[MultiModalFieldElem]:
+        field = self.field_cls(key=key, modality=self.modality)
+        return field.build_elems(batch, **self.field_config)  # type: ignore


+class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
+    """
+    A collection of :class:`MultiModalFieldElem`
+    corresponding to a data item in :class:`MultiModalDataItems`.
+    """
+
+    @staticmethod
+    def from_elems(elems: Sequence[MultiModalFieldElem]):
+        return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
+
+    @property
+    def modality(self) -> str:
+        modalities = {elem.field.modality for elem in self.data.values()}
+        assert len(modalities) == 1, f"Found different modalities={modalities}"
+        return next(iter(modalities))
+
+
+# NOTE: UserDict is for V0 compatibility.
+# V1 should access individual items via `get_item`.
 class MultiModalKwargs(UserDict[str, NestedTensors]):
     """
     A dictionary that represents the keyword arguments to
     :meth:`~torch.nn.Module.forward`.

-    The metadata :code:`items_by_key` defines how to split batched keyword
-    arguments corresponding to each data item in :class:`MultiModalDataItems`:
-
-    - For a keyword argument, we can access the :code:`i` th item in the batch
-      via :code:`items_by_key[key][i]`.
-    - We can gather the keyword arguments belonging to a modality by finding
-      the keys with items that belong to that modality, then accessing
-      the :code:`i` th item in the batch for each such key.
-
-    Example:
-
-        .. code-block:: python
-
-            # All items belong to the "image" modality
-            items_by_key={
-                "pixel_values": [a, b, c, d],  # "image" modality
-                "image_grid_thw": [e, f, g, h],  # "image" modality
-                "pixel_values_video": [h, i, j],  # "video" modality
-                "video_grid_thw": [k, l, m],  # "video" modality
-            }
-
-    - The keyword arguments belonging to the first image are
-      :code:`{"pixel_values": a, "image_grid_thw": e}`.
-    - The keyword arguments belonging to the second video are
-      :code:`{"pixel_values_video": i, "video_grid_thw": l}`.
+    The metadata :code:`items` enables us to obtain the keyword arguments
+    corresponding to each data item in :class:`MultiModalDataItems`, via
+    :meth:`get_item` and :meth:`get_items`.
     """

     @staticmethod
     def from_hf_inputs(
         hf_inputs: BatchFeature,
         config_by_key: Mapping[str, MultiModalFieldConfig],
-        *,
-        enable_sanity_checks: bool = False,
     ):
         # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
         # We assume that those fields are not used in vLLM
-        items_by_key = {
-            key: config.build_items(key, batch)
-            for key, config in config_by_key.items()
-            if (batch := hf_inputs.get(key)) is not None
-        }
+        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
+        keys_by_modality = defaultdict[str, set[str]](set)
+        for key, config in config_by_key.items():
+            batch = hf_inputs.get(key)
+            if batch is not None:
+                elems = config.build_elems(key, batch)
+                if len(elems) > 0:
+                    elems_by_key[key] = elems
+                    keys_by_modality[config.modality].add(key)

-        return MultiModalKwargs.from_items_by_key(
-            items_by_key,
-            enable_sanity_checks=enable_sanity_checks,
-        )
+        items = list[MultiModalKwargsItem]()
+        for modality, keys in keys_by_modality.items():
+            elems_in_modality = {k: elems_by_key[k] for k in keys}
+            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
+
+            if len(set(batch_sizes.values())) > 1:
+                raise ValueError(
+                    f"Cannot merge different batch sizes for {modality=}! "
+                    f"Found: {batch_sizes=}")
+
+            batch_size = next(iter(batch_sizes.values()))
+            for item_idx in range(batch_size):
+                elems = [v[item_idx] for v in elems_in_modality.values()]
+                items.append(MultiModalKwargsItem.from_elems(elems))
+
+        return MultiModalKwargs.from_items(items)

     @staticmethod
-    def from_items_by_key(
-        items_by_key: Mapping[str, list[MultiModalFieldItem]],
-        *,
-        enable_sanity_checks: bool = False,
-    ) -> "MultiModalKwargs":
+    def from_items(items: Sequence[MultiModalKwargsItem]):
+        """Construct a new :class:`MultiModalKwargs` from multiple items."""
+        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
+        for item in items:
+            for key, elem in item.items():
+                elems_by_key[key].append(elem)
+
         data = {
-            key: items[0].field.reduce(items).data
-            for key, items in items_by_key.items() if len(items) > 0
+            key: elems[0].field.reduce(elems).data
+            for key, elems in elems_by_key.items() if len(elems) > 0
         }

-        return MultiModalKwargs(data,
-                                items_by_key=items_by_key,
-                                enable_sanity_checks=enable_sanity_checks)
+        return MultiModalKwargs(data, items=items)

     def __init__(
         self,
         data: Mapping[str, NestedTensors],
         *,
-        items_by_key: Mapping[str, list[MultiModalFieldItem]] = {},
-        enable_sanity_checks: bool = False,
+        items: Optional[Sequence[MultiModalKwargsItem]] = None,
     ) -> None:
         super().__init__(data)

-        # Shallow copy to avoid footgun in case a defaultdict is passed in
-        self._items_by_key = dict(items_by_key)
-
-        keys_by_modality = defaultdict[str, set[str]](set)
-        for key, items in items_by_key.items():
-            for item in items:
-                keys_by_modality[item.field.modality].add(key)
-
-        self._keys_by_modality = dict(keys_by_modality)
-
-        if enable_sanity_checks:
-            for modality, keys in keys_by_modality.items():
-                items_in_modality = {k: items_by_key[k] for k in keys}
-                batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
-                batch_size = next(iter(batch_sizes.values()), 0)
-                assert all(bs == batch_size
-                           for bs in batch_sizes.values()), dict(
-                               modality=modality,
-                               batch_sizes=batch_sizes,
-                               items_by_key=items_by_key)
+        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
+        self._items_by_modality = dict(items_by_modality)
+
+    @property
+    def modalities(self):
+        return self._items_by_modality.keys()

     @staticmethod
     def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
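The net effect of this hunk is that `from_hf_inputs` now groups the per-key elements into one `MultiModalKwargsItem` per data item, and `from_items` rebuilds the batched kwargs from those items. A hedged usage sketch, assuming the `MultiModalFieldConfig.batched` helper and the `vllm.multimodal.inputs` module path (neither is shown in this diff):

    import torch
    from transformers import BatchFeature
    from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs

    # HF-processor-style output for two images (dummy tensors).
    hf_inputs = BatchFeature({
        "pixel_values": torch.zeros(2, 3, 336, 336),
        "image_grid_thw": torch.zeros(2, 3),
    })
    config_by_key = {
        "pixel_values": MultiModalFieldConfig.batched("image"),
        "image_grid_thw": MultiModalFieldConfig.batched("image"),
    }

    mm_kwargs = MultiModalKwargs.from_hf_inputs(hf_inputs, config_by_key)

    # One MultiModalKwargsItem per image, holding every key for that item.
    item0 = mm_kwargs.get_item("image", 0)
    rebuilt = MultiModalKwargs.from_items([item0])  # kwargs for just that item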
@@ -452,58 +446,44 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        if self._items_by_key != other._items_by_key:
+        if self._items_by_modality != other._items_by_modality:
             return False

         ks = self.keys()
         return (ks == other.keys()
                 and all(nested_tensors_equal(self[k], other[k]) for k in ks))

-    def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
-        return self._items_by_key[key][item_index]
+    def _validate_modality(self, method_name: str, modality: str) -> None:
+        if not self._items_by_modality:
+            raise RuntimeError(
+                f"`{method_name}` is not supported when "
+                "MultiModalKwargs is not initialized with `items`")

-    def get_items_by_modality(
-        self,
-        modality: str,
-        item_index: int,
-    ) -> Mapping[str, MultiModalFieldItem]:
+        if modality not in self._items_by_modality:
+            available_modalities = set(self._items_by_modality.keys())
+            raise KeyError(f"Modality {modality!r} not found. "
+                           f"Available modalities: {available_modalities}")
+
+    def get_item_count(self, modality: str) -> int:
+        """Get the number of items belonging to a modality."""
+        self._validate_modality("get_item_count", modality)
+        return len(self._items_by_modality[modality])
+
+    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
         """
         Get the keyword arguments corresponding to an item identified by
         its modality and index.
         """
-        if modality not in self._keys_by_modality:
-            available_modalities = set(self._keys_by_modality.keys())
-            raise KeyError(f"Modality {modality!r} not found. "
-                           f"Available modalities: {available_modalities}")
+        self._validate_modality("get_item", modality)
+        return self._items_by_modality[modality][item_index]

-        keys_to_gather = self._keys_by_modality[modality]
-
-        return {
-            key: self.get_item(key, item_index)
-            for key in keys_to_gather if key in self
-        }
-
-    @staticmethod
-    def from_items_by_modality(
-        items_by_modality: Mapping[str, list[Mapping[str,
-                                                     MultiModalFieldItem]]],
-        *,
-        enable_sanity_checks: bool = False,
-    ) -> "MultiModalKwargs":
+    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
         """
-        Construct a new :class:`MultiModalKwargs` from multiple items returned
-        by :meth:`get_fields_by_modality`.
+        Get the keyword arguments corresponding to each item belonging to
+        a modality.
         """
-        items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
-        for fields in items_by_modality.values():
-            for field in fields:
-                for k, v in field.items():
-                    items_by_key[k].append(v)
-
-        return MultiModalKwargs.from_items_by_key(
-            items_by_key,
-            enable_sanity_checks=enable_sanity_checks,
-        )
+        self._validate_modality("get_items", modality)
+        return self._items_by_modality[modality]


 MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
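With the accessors above, V1 code can walk the kwargs item by item instead of key by key; kwargs constructed without `items` keep the V0 dict interface but reject the item-level accessors. A short sketch continuing the `mm_kwargs` built in the previous example (same assumptions):

    import torch
    from vllm.multimodal.inputs import MultiModalKwargs

    # Item-centric iteration over a kwargs object built via from_hf_inputs /
    # from_items (see the earlier sketch).
    for modality in mm_kwargs.modalities:
        for idx in range(mm_kwargs.get_item_count(modality)):
            item = mm_kwargs.get_item(modality, idx)
            print(modality, idx, sorted(item.keys()))

    # A MultiModalKwargs built from plain dict data (V0 path) has no `items`,
    # so the item-level accessors raise instead of returning partial results.
    legacy = MultiModalKwargs({"pixel_values": torch.zeros(2, 3, 336, 336)})
    try:
        legacy.get_item("image", 0)
    except RuntimeError:
        pass  # "not initialized with `items`"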
@@ -20,8 +20,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import LRUCache, flatten_2d_lists, full_groupby

 from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                     MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs,
-                     PlaceholderRange)
+                     MultiModalInputsV2, MultiModalKwargs,
+                     MultiModalKwargsItem, PlaceholderRange)
 from .parse import MultiModalDataItems, MultiModalDataParser

 logger = init_logger(__name__)
@@ -496,8 +496,7 @@ class ProcessingCache:
         # DEBUG: Set to None to disable
         self.debug_cache_hit_ratio_steps: Optional[int] = None

-        self._cache = LRUCache[str, Mapping[str,
-                                            MultiModalFieldItem]](capacity)
+        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)

     def _maybe_log_cache_stats(self) -> None:
         steps = self.debug_cache_hit_ratio_steps
@@ -565,7 +564,7 @@ class ProcessingCache:
         modality: str,
         input_item: object,
         input_kwargs: Mapping[str, object],
-    ) -> Optional[Mapping[str, MultiModalFieldItem]]:
+    ) -> Optional[MultiModalKwargsItem]:
         """
         Get a processed multi-modal item from the cache
         according to its dependencies, including:
@@ -588,7 +587,7 @@ class ProcessingCache:
         modality: str,
         input_item: object,
         input_kwargs: Mapping[str, object],
-        output_kwargs: Mapping[str, MultiModalFieldItem],
+        output_kwargs: MultiModalKwargsItem,
     ) -> None:
         """
         Put a processed multi-modal item into the cache
@@ -784,7 +783,6 @@ class BaseMultiModalProcessor(ABC):
         mm_kwargs = MultiModalKwargs.from_hf_inputs(
             processed_data,
             self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
-            enable_sanity_checks=self.enable_sanity_checks,
         )

         return prompt_ids, mm_kwargs
@@ -846,7 +844,7 @@ class BaseMultiModalProcessor(ABC):
             hf_processor_mm_kwargs=hf_processor_mm_kwargs,
         )

-        mm_maybe_cached_field_items = {
+        mm_maybe_cached_kw_items = {
             modality: [
                 cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                 for item in items
@@ -855,8 +853,9 @@ class BaseMultiModalProcessor(ABC):
         }

         mm_missing_idxs = {
-            modality: [idx for idx, out in enumerate(fields) if out is None]
-            for modality, fields in mm_maybe_cached_field_items.items()
+            modality:
+            [idx for idx, item in enumerate(kw_items) if item is None]
+            for modality, kw_items in mm_maybe_cached_kw_items.items()
         }
         mm_missing_data = {
             modality: [mm_data_items[modality][idx] for idx in idxs]
@@ -875,14 +874,11 @@ class BaseMultiModalProcessor(ABC):
             for modality in mm_missing_data_items
         }

-        mm_merged_field_items = dict[str, list[Mapping[str,
-                                                       MultiModalFieldItem]]]()
-        for modality, modal_items_lst in mm_maybe_cached_field_items.items():
-            merged_modal_items_lst = list[Mapping[str, MultiModalFieldItem]]()
-
-            for idx, modal_items in enumerate(modal_items_lst):
-                if modal_items is None:
-                    modal_items = mm_missing_kwargs.get_items_by_modality(
+        merged_kw_items = list[MultiModalKwargsItem]()
+        for modality, kw_items in mm_maybe_cached_kw_items.items():
+            for idx, kw_item in enumerate(kw_items):
+                if kw_item is None:
+                    kw_item = mm_missing_kwargs.get_item(
                         modality,
                         mm_missing_next_idx[modality],
                     )
@@ -892,14 +888,12 @@ class BaseMultiModalProcessor(ABC):
                         modality,
                         mm_data_items[modality][idx],
                         hf_processor_mm_kwargs,
-                        modal_items,
+                        kw_item,
                     )

                     mm_missing_next_idx[modality] += 1

-                merged_modal_items_lst.append(modal_items)
-
-            mm_merged_field_items[modality] = merged_modal_items_lst
+                merged_kw_items.append(kw_item)

         if self.enable_sanity_checks:
             mm_missing_counts = mm_missing_data_items.get_all_counts()
@@ -909,10 +903,7 @@ class BaseMultiModalProcessor(ABC):
                     mm_missing_next_idx=mm_missing_next_idx,
                     mm_missing_counts=mm_missing_counts)

-        mm_kwargs = MultiModalKwargs.from_items_by_modality(
-            mm_merged_field_items,
-            enable_sanity_checks=self.enable_sanity_checks,
-        )
+        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

         if self.enable_sanity_checks:
             mm_item_counts = mm_data_items.get_all_counts()
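The merge loop above interleaves cached `MultiModalKwargsItem`s with freshly processed ones: each `None` slot (a cache miss) is filled from `mm_missing_kwargs` in order, per modality, before `MultiModalKwargs.from_items` recombines everything. A minimal sketch of that pattern with hypothetical stand-in values:

    from collections import defaultdict

    # Stand-ins for mm_maybe_cached_kw_items / mm_missing_kwargs:
    # None marks a cache miss that must be filled from the fresh results.
    cached = {"image": ["item_a", None, "item_c"]}
    fresh_by_modality = {"image": ["item_b_fresh"]}

    next_idx = defaultdict(int)
    merged = []
    for modality, kw_items in cached.items():
        for kw_item in kw_items:
            if kw_item is None:
                kw_item = fresh_by_modality[modality][next_idx[modality]]
                next_idx[modality] += 1
            merged.append(kw_item)

    assert merged == ["item_a", "item_b_fresh", "item_c"]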
@@ -920,7 +911,7 @@ class BaseMultiModalProcessor(ABC):
             for modality, item_count in mm_item_counts.items():
                 for item_idx in range(item_count):
                     try:
-                        mm_kwargs.get_items_by_modality(modality, item_idx)
+                        mm_kwargs.get_item(modality, item_idx)
                     except Exception as e:
                         # Make it easy to set a breakpoint in the debugger
                         raise e
@@ -113,15 +113,27 @@ class Processor:

         # For merged preprocessor, mm_data is already mm_inputs
         precomputed_mm_inputs = None
-        if isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
-            precomputed_mm_inputs = [decoder_inputs.multi_modal_data]
+        decoder_mm_data = decoder_inputs.multi_modal_data
+        if isinstance(decoder_mm_data, MultiModalKwargs):
+            # The output of merged multi-modal processor (`decoder_mm_data`)
+            # contains the kwargs for all items from all modalities.
+            # This code separates them so that there is one set of kwargs
+            # per item per modality.
+            precomputed_mm_inputs = [
+                MultiModalKwargs.from_items([item])
+                for modality in decoder_mm_data.modalities
+                for item in decoder_mm_data.get_items(modality)
+            ]

         # Apply MM mapper
         mm_inputs = None
-        if len(decoder_inputs.multi_modal_data) > 0:
+        if len(decoder_mm_data) > 0:
             mm_inputs = self.mm_input_mapper_client.process_inputs(
-                decoder_inputs.multi_modal_data, mm_hashes,
-                decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
+                decoder_mm_data,
+                mm_hashes,
+                decoder_inputs.mm_processor_kwargs,
+                precomputed_mm_inputs,
+            )

         return EngineCoreRequest(
             request_id,
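This is the piece that makes the merged processor compatible with V1 multi-input: the single `MultiModalKwargs` returned by the processor is split into one `MultiModalKwargs` per item, so the V1 input mapper can treat each image or video independently. A hedged sketch of the same split outside the engine, reusing a `mm_kwargs` built as in the earlier example:

    from vllm.multimodal.inputs import MultiModalKwargs

    # One MultiModalKwargs per item, per modality.
    per_item_inputs = [
        MultiModalKwargs.from_items([item])
        for modality in mm_kwargs.modalities
        for item in mm_kwargs.get_items(modality)
    ]

    # The split covers every item exactly once.
    assert len(per_item_inputs) == sum(
        mm_kwargs.get_item_count(m) for m in mm_kwargs.modalities)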