[Doc] Guide for adding multi-modal plugins (#6205)

Cyrus Leung, 2024-07-10 14:55:34 +08:00 (committed by GitHub)
parent 5ed3505d82
commit 8a924d2248
7 changed files with 64 additions and 23 deletions

View File

@@ -5,6 +5,7 @@
     justify-content: center;
     align-items: center;
     font-size: 16px;
+    padding: 0 6px 0 6px;
 }
 .notification-bar p {
     margin: 0;

View File

@@ -0,0 +1,17 @@
+.. _adding_multimodal_plugin:
+
+Adding a Multimodal Plugin
+==========================
+
+This document teaches you how to add a new modality to vLLM.
+
+Each modality in vLLM is represented by a :class:`~vllm.multimodal.MultiModalPlugin` and registered to :data:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
+For vLLM to recognize a new modality type, you have to create a new plugin and then pass it to :meth:`~vllm.multimodal.MultiModalRegistry.register_plugin`.
+
+The remainder of this document details how to define custom :class:`~vllm.multimodal.MultiModalPlugin`\ s.
+
+.. note::
+    This article is a work in progress.
+
+..
+    TODO: Add more instructions on how to add new plugins once embeddings is in.
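To make the new guide concrete, here is a minimal sketch of the workflow it describes, using a hypothetical "audio" modality. The names ``AudioPlugin`` and ``audio_features`` are illustrative only, and the exact set of methods to override may vary across vLLM versions:

    # A sketch of a custom plugin for a hypothetical "audio" modality.
    import torch

    from vllm.inputs import InputContext
    from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
                                 MultiModalPlugin)


    class AudioPlugin(MultiModalPlugin):
        """Plugin for (hypothetical) audio data."""

        def get_data_key(self) -> str:
            # The key under which this modality appears in ``multi_modal_data``.
            return "audio"

        def _default_input_mapper(self, ctx: InputContext,
                                  data: object) -> MultiModalInputs:
            # Convert the raw data into keyword arguments for the model.
            if isinstance(data, torch.Tensor):
                return MultiModalInputs({"audio_features": data})
            raise TypeError(f"Unsupported audio data type: {type(data)}")


    # Pass the plugin to the registry so vLLM recognizes the new modality.
    MULTIMODAL_REGISTRY.register_plugin(AudioPlugin())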

View File

@@ -7,17 +7,21 @@ Multi-Modality
 
 vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
 
-Multi-modal input can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
+Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
 via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
 
-.. note::
-    ``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through
-    the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
+Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
+by following :ref:`this guide <adding_multimodal_plugin>`.
 
-To implement a new multi-modal model in vLLM, please follow :ref:`this guide <enabling_multimodal_inputs>`.
+Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.
 
-..
-    TODO: Add more instructions on how to add new plugins once embeddings is in.
+Guides
+++++++
+
+.. toctree::
+   :maxdepth: 1
+
+   adding_multimodal_plugin
 
 Module Contents
 +++++++++++++++
@@ -36,10 +40,14 @@ Registry
 Base Classes
 ------------
 
-.. autoclass:: vllm.multimodal.MultiModalDataDict
+.. autodata:: vllm.multimodal.BatchedTensors
+
+.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
     :members:
     :show-inheritance:
 
+.. autodata:: vllm.multimodal.MultiModalDataDict
+
 .. autoclass:: vllm.multimodal.MultiModalInputs
     :members:
     :show-inheritance:
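As a usage sketch of the ``multi_modal_data`` field described above (the model name and prompt format are illustrative, not prescribed by this commit):

    from PIL import Image

    from vllm import LLM

    # Any vLLM-supported vision-language model works; this one is illustrative.
    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    image = Image.open("example.jpg")

    outputs = llm.generate({
        "prompt": "USER: <image>\nWhat is in this image? ASSISTANT:",
        "multi_modal_data": {"image": image},
    })
    print(outputs[0].outputs[0].text)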

View File

@@ -1,5 +1,5 @@
-from .base import (BatchedTensors, MultiModalDataDict, MultiModalInputs,
-                   MultiModalPlugin)
+from .base import (BatchedTensors, MultiModalDataBuiltins, MultiModalDataDict,
+                   MultiModalInputs, MultiModalPlugin)
 from .registry import MultiModalRegistry
 
 MULTIMODAL_REGISTRY = MultiModalRegistry()
@@ -13,6 +13,7 @@ See also:
 __all__ = [
     "BatchedTensors",
+    "MultiModalDataBuiltins",
     "MultiModalDataDict",
     "MultiModalInputs",
     "MultiModalPlugin",

View File

@@ -43,9 +43,6 @@ class MultiModalInputs(_MultiModalInputsBase):
         *,
         device: torch.types.Device,
     ) -> BatchedTensors:
-        # Avoid initializing CUDA too early
-        import torch
-
         unbatched_shape = tensors[0].shape[1:]
 
         for tensor in tensors:
@@ -84,16 +81,21 @@ class MultiModalInputs(_MultiModalInputsBase):
 class MultiModalDataBuiltins(TypedDict, total=False):
+    """Modality types that are predefined by vLLM."""
+
     image: Image.Image
+    """The input image."""
 
 
 MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
 """
 A dictionary containing an item for each modality type to input.
 
-The data belonging to each modality is converted into keyword arguments
-to the model by the corresponding mapper. By default, the mapper of
-the corresponding plugin with the same modality key is applied.
+Note:
+    This dictionary also accepts modality keys defined outside
+    :class:`MultiModalDataBuiltins` as long as a customized plugin is registered
+    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
+
+    Read more on that :ref:`here <adding_multimodal_plugin>`.
 """
 
 MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
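A sketch of the two kinds of keys that ``MultiModalDataDict`` accepts; the ``audio`` key is hypothetical and only valid once a matching plugin is registered:

    import torch
    from PIL import Image

    from vllm.multimodal import MultiModalDataDict

    mm_data: MultiModalDataDict = {
        # Built-in key, declared in MultiModalDataBuiltins.
        "image": Image.open("example.jpg"),
        # Custom key (hypothetical); requires a registered "audio" plugin.
        "audio": torch.rand(16000),
    }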
@@ -123,6 +125,9 @@ class MultiModalPlugin(ABC):
     process the same data differently). This registry is in turn used by
     :class:`~MultiModalRegistry` which acts at a higher level
     (i.e., the modality of the data).
+
+    See also:
+        :ref:`adding_multimodal_plugin`
     """
 
     def __init__(self) -> None:
@@ -183,8 +188,8 @@ class MultiModalPlugin(ABC):
     def map_input(self, model_config: ModelConfig,
                   data: object) -> MultiModalInputs:
         """
-        Apply an input mapper to a data passed
-        to the model, transforming the data into a dictionary of model inputs.
+        Transform the data into a dictionary of model inputs using the
+        input mapper registered for that model.
 
         The model is identified by ``model_config``.
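Because ``MultiModalInputMapper`` is ``Callable[[InputContext, object], MultiModalInputs]``, a custom mapper is just a function of that shape. A minimal sketch (the preprocessing shown is illustrative, not vLLM's default):

    import numpy as np
    import torch
    from PIL import Image

    from vllm.inputs import InputContext
    from vllm.multimodal import MultiModalInputs


    def my_input_mapper(ctx: InputContext, data: object) -> MultiModalInputs:
        # ``ctx`` exposes the target model's configuration.
        if isinstance(data, Image.Image):
            pixel_values = torch.from_numpy(np.asarray(data)).float()
            return MultiModalInputs({"pixel_values": pixel_values})
        raise TypeError(f"Unsupported data type: {type(data)}")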

View File

@@ -100,6 +100,7 @@ def repeat_and_pad_image_tokens(
 class ImagePlugin(MultiModalPlugin):
+    """Plugin for image data."""
 
     def get_data_key(self) -> str:
         return "image"

View File

@@ -15,10 +15,8 @@ logger = init_logger(__name__)
 class MultiModalRegistry:
     """
-    A registry to dispatch data processing
-    according to its modality and the target model.
-
-    The registry handles both external and internal data input.
+    A registry that dispatches data processing to the
+    :class:`~vllm.multimodal.MultiModalPlugin` for each modality.
     """
 
     DEFAULT_PLUGINS = (ImagePlugin(), )
@@ -30,6 +28,12 @@ class MultiModalRegistry:
         self._plugins = {p.get_data_key(): p for p in plugins}
 
     def register_plugin(self, plugin: MultiModalPlugin) -> None:
+        """
+        Register a multi-modal plugin so it can be recognized by vLLM.
+
+        See also:
+            :ref:`adding_multimodal_plugin`
+        """
         data_type_key = plugin.get_data_key()
 
         if data_type_key in self._plugins:
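A usage sketch of the docstring added here, reusing the hypothetical ``AudioPlugin`` from the earlier sketch; registration should happen before the engine is created:

    from vllm.multimodal import MULTIMODAL_REGISTRY

    # ``AudioPlugin`` is the hypothetical plugin sketched earlier.
    MULTIMODAL_REGISTRY.register_plugin(AudioPlugin())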
@@ -75,7 +79,11 @@ class MultiModalRegistry:
                   data: MultiModalDataDict) -> MultiModalInputs:
         """
         Apply an input mapper to the data passed to the model.
+
+        The data belonging to each modality is passed to the corresponding
+        plugin which in turn converts the data into keyword arguments
+        via the input mapper registered for that model.
 
         See :meth:`MultiModalPlugin.map_input` for more details.
         """
         merged_dict: Dict[str, torch.Tensor] = {}
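To illustrate the merging described in the new docstring: each modality key is dispatched to its plugin, and the resulting per-plugin keyword arguments are combined into one dict. A conceptual sketch (tensor shapes are placeholders, and ``audio_features`` assumes a custom plugin):

    import torch

    from vllm.multimodal import MultiModalInputs

    # What the per-modality mappers might produce ...
    image_kwargs = MultiModalInputs({"pixel_values": torch.rand(1, 3, 336, 336)})
    audio_kwargs = MultiModalInputs({"audio_features": torch.rand(1, 16000)})

    # ... and how map_input merges them into a single kwargs dict.
    merged = {**image_kwargs, **audio_kwargs}
    assert set(merged) == {"pixel_values", "audio_features"}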