
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com> Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com> Co-authored-by: ywang96 <ywang@roblox.com> Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com> Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
123 lines
4.1 KiB
Python
123 lines
4.1 KiB
Python
import functools
|
|
from typing import Dict, Optional, Sequence
|
|
|
|
import torch
|
|
|
|
from vllm.config import ModelConfig
|
|
from vllm.logger import init_logger
|
|
|
|
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
|
|
MultiModalPlugin)
|
|
from .image import ImagePlugin
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class MultiModalRegistry:
|
|
"""
|
|
A registry to dispatch data processing
|
|
according to its modality and the target model.
|
|
|
|
The registry handles both external and internal data input.
|
|
"""
|
|
|
|
DEFAULT_PLUGINS = (ImagePlugin(), )
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
|
|
self._plugins = {p.get_data_key(): p for p in plugins}
|
|
|
|
def register_plugin(self, plugin: MultiModalPlugin) -> None:
|
|
data_type_key = plugin.get_data_key()
|
|
|
|
if data_type_key in self._plugins:
|
|
logger.warning(
|
|
"A plugin is already registered for data type %s, "
|
|
"and will be overwritten by the new plugin %s.", data_type_key,
|
|
plugin)
|
|
|
|
self._plugins[data_type_key] = plugin
|
|
|
|
def _get_plugin(self, data_type_key: str):
|
|
plugin = self._plugins.get(data_type_key)
|
|
if plugin is not None:
|
|
return plugin
|
|
|
|
msg = f"Unknown multi-modal data type: {data_type_key}"
|
|
raise NotImplementedError(msg)
|
|
|
|
def register_image_input_mapper(
|
|
self,
|
|
mapper: Optional[MultiModalInputMapper] = None,
|
|
):
|
|
"""
|
|
Register an input mapper for image data to a model class.
|
|
|
|
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
|
|
"""
|
|
return self.register_input_mapper("image", mapper)
|
|
|
|
def _process_input(self, key: str, value: object,
|
|
model_config: ModelConfig) -> MultiModalInputs:
|
|
plugin = self._plugins.get(key)
|
|
if plugin:
|
|
return plugin.map_input(model_config, value)
|
|
msg = f"Unknown multi-modal data type: {key}"
|
|
raise NotImplementedError(msg)
|
|
|
|
def register_input_mapper(
|
|
self,
|
|
data_type: str,
|
|
mapper: Optional[MultiModalInputMapper] = None,
|
|
):
|
|
"""
|
|
Register an input mapper for a specific modality to a model class.
|
|
|
|
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
|
|
"""
|
|
plugin = self._plugins.get(data_type)
|
|
if not plugin:
|
|
msg = f"Unknown multi-modal data type: {data_type}"
|
|
raise NotImplementedError(msg)
|
|
return plugin.register_input_mapper(mapper)
|
|
|
|
def register_image_input(self,
|
|
mapper: Optional[MultiModalInputMapper] = None):
|
|
"""
|
|
Register an input mapper for image pixel data to a model class.
|
|
|
|
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
|
|
"""
|
|
return self.register_input_mapper("image", mapper)
|
|
|
|
def map_input(self, model_config: ModelConfig,
|
|
data: MultiModalDataDict) -> MultiModalInputs:
|
|
"""
|
|
Apply an input mapper to the data passed to the model.
|
|
|
|
See :meth:`MultiModalPlugin.map_input` for more details.
|
|
"""
|
|
merged_dict: Dict[str, torch.Tensor] = {}
|
|
|
|
for data_key, data_value in data.items():
|
|
input_dict = self._process_input(data_key, data_value,
|
|
model_config)
|
|
|
|
for input_key, input_tensor in input_dict.items():
|
|
if input_key in merged_dict:
|
|
raise ValueError(f"The input mappers (keys={set(data)}) "
|
|
f"resulted in a conflicting keyword "
|
|
f"argument to `forward()`: {input_key}")
|
|
|
|
merged_dict[input_key] = input_tensor
|
|
|
|
return MultiModalInputs(merged_dict)
|
|
|
|
def create_input_mapper(self, model_config: ModelConfig):
|
|
"""
|
|
Create an input mapper (see :meth:`map_input`) for a specific model.
|
|
"""
|
|
return functools.partial(self.map_input, model_config)
|