[Doc] Guide for adding multi-modal plugins (#6205)
This commit is contained in:
parent
5ed3505d82
commit
8a924d2248
@ -5,6 +5,7 @@
|
|||||||
justify-content: center;
|
justify-content: center;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
font-size: 16px;
|
font-size: 16px;
|
||||||
|
padding: 0 6px 0 6px;
|
||||||
}
|
}
|
||||||
.notification-bar p {
|
.notification-bar p {
|
||||||
margin: 0;
|
margin: 0;
|
||||||
|
17
docs/source/dev/multimodal/adding_multimodal_plugin.rst
Normal file
17
docs/source/dev/multimodal/adding_multimodal_plugin.rst
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
.. _adding_multimodal_plugin:
|
||||||
|
|
||||||
|
Adding a Multimodal Plugin
|
||||||
|
==========================
|
||||||
|
|
||||||
|
This document teaches you how to add a new modality to vLLM.
|
||||||
|
|
||||||
|
Each modality in vLLM is represented by a :class:`~vllm.multimodal.MultiModalPlugin` and registered to :data:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
|
||||||
|
For vLLM to recognize a new modality type, you have to create a new plugin and then pass it to :meth:`~vllm.multimodal.MultiModalRegistry.register_plugin`.
|
||||||
|
|
||||||
|
The remainder of this document details how to define custom :class:`~vllm.multimodal.MultiModalPlugin`\ s.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
This article is a work in progress.
|
||||||
|
|
||||||
|
..
|
||||||
|
TODO: Add more instructions on how to add new plugins once embeddings is in.
|
@ -7,17 +7,21 @@ Multi-Modality
|
|||||||
|
|
||||||
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
|
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
|
||||||
|
|
||||||
Multi-modal input can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
|
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
|
||||||
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
|
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
|
||||||
|
|
||||||
.. note::
|
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
|
||||||
``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through
|
by following :ref:`this guide <adding_multimodal_plugin>`.
|
||||||
the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
|
|
||||||
|
|
||||||
To implement a new multi-modal model in vLLM, please follow :ref:`this guide <enabling_multimodal_inputs>`.
|
Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.
|
||||||
|
|
||||||
..
|
Guides
|
||||||
TODO: Add more instructions on how to add new plugins once embeddings is in.
|
++++++
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
adding_multimodal_plugin
|
||||||
|
|
||||||
Module Contents
|
Module Contents
|
||||||
+++++++++++++++
|
+++++++++++++++
|
||||||
@ -36,10 +40,14 @@ Registry
|
|||||||
Base Classes
|
Base Classes
|
||||||
------------
|
------------
|
||||||
|
|
||||||
.. autoclass:: vllm.multimodal.MultiModalDataDict
|
.. autodata:: vllm.multimodal.BatchedTensors
|
||||||
|
|
||||||
|
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
|
||||||
:members:
|
:members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
||||||
|
.. autodata:: vllm.multimodal.MultiModalDataDict
|
||||||
|
|
||||||
.. autoclass:: vllm.multimodal.MultiModalInputs
|
.. autoclass:: vllm.multimodal.MultiModalInputs
|
||||||
:members:
|
:members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from .base import (BatchedTensors, MultiModalDataDict, MultiModalInputs,
|
from .base import (BatchedTensors, MultiModalDataBuiltins, MultiModalDataDict,
|
||||||
MultiModalPlugin)
|
MultiModalInputs, MultiModalPlugin)
|
||||||
from .registry import MultiModalRegistry
|
from .registry import MultiModalRegistry
|
||||||
|
|
||||||
MULTIMODAL_REGISTRY = MultiModalRegistry()
|
MULTIMODAL_REGISTRY = MultiModalRegistry()
|
||||||
@ -13,6 +13,7 @@ See also:
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BatchedTensors",
|
"BatchedTensors",
|
||||||
|
"MultiModalDataBuiltins",
|
||||||
"MultiModalDataDict",
|
"MultiModalDataDict",
|
||||||
"MultiModalInputs",
|
"MultiModalInputs",
|
||||||
"MultiModalPlugin",
|
"MultiModalPlugin",
|
||||||
|
@ -43,9 +43,6 @@ class MultiModalInputs(_MultiModalInputsBase):
|
|||||||
*,
|
*,
|
||||||
device: torch.types.Device,
|
device: torch.types.Device,
|
||||||
) -> BatchedTensors:
|
) -> BatchedTensors:
|
||||||
# Avoid initializing CUDA too early
|
|
||||||
import torch
|
|
||||||
|
|
||||||
unbatched_shape = tensors[0].shape[1:]
|
unbatched_shape = tensors[0].shape[1:]
|
||||||
|
|
||||||
for tensor in tensors:
|
for tensor in tensors:
|
||||||
@ -84,16 +81,21 @@ class MultiModalInputs(_MultiModalInputsBase):
|
|||||||
|
|
||||||
|
|
||||||
class MultiModalDataBuiltins(TypedDict, total=False):
|
class MultiModalDataBuiltins(TypedDict, total=False):
|
||||||
|
"""Modality types that are predefined by vLLM."""
|
||||||
|
|
||||||
image: Image.Image
|
image: Image.Image
|
||||||
|
"""The input image."""
|
||||||
|
|
||||||
|
|
||||||
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
|
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
|
||||||
"""
|
"""
|
||||||
A dictionary containing an item for each modality type to input.
|
A dictionary containing an item for each modality type to input.
|
||||||
|
|
||||||
The data belonging to each modality is converted into keyword arguments
|
Note:
|
||||||
to the model by the corresponding mapper. By default, the mapper of
|
This dictionary also accepts modality keys defined outside
|
||||||
the corresponding plugin with the same modality key is applied.
|
:class:`MultiModalDataBuiltins` as long as a customized plugin is registered
|
||||||
|
through the :data:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
|
||||||
|
Read more on that :ref:`here <adding_multimodal_plugin>`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
|
MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
|
||||||
@ -123,6 +125,9 @@ class MultiModalPlugin(ABC):
|
|||||||
process the same data differently). This registry is in turn used by
|
process the same data differently). This registry is in turn used by
|
||||||
:class:`~MultiModalRegistry` which acts at a higher level
|
:class:`~MultiModalRegistry` which acts at a higher level
|
||||||
(i.e., the modality of the data).
|
(i.e., the modality of the data).
|
||||||
|
|
||||||
|
See also:
|
||||||
|
:ref:`adding_multimodal_plugin`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@ -183,8 +188,8 @@ class MultiModalPlugin(ABC):
|
|||||||
def map_input(self, model_config: ModelConfig,
|
def map_input(self, model_config: ModelConfig,
|
||||||
data: object) -> MultiModalInputs:
|
data: object) -> MultiModalInputs:
|
||||||
"""
|
"""
|
||||||
Apply an input mapper to a data passed
|
Transform the data into a dictionary of model inputs using the
|
||||||
to the model, transforming the data into a dictionary of model inputs.
|
input mapper registered for that model.
|
||||||
|
|
||||||
The model is identified by ``model_config``.
|
The model is identified by ``model_config``.
|
||||||
|
|
||||||
|
@ -100,6 +100,7 @@ def repeat_and_pad_image_tokens(
|
|||||||
|
|
||||||
|
|
||||||
class ImagePlugin(MultiModalPlugin):
|
class ImagePlugin(MultiModalPlugin):
|
||||||
|
"""Plugin for image data."""
|
||||||
|
|
||||||
def get_data_key(self) -> str:
|
def get_data_key(self) -> str:
|
||||||
return "image"
|
return "image"
|
||||||
|
@ -15,10 +15,8 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
class MultiModalRegistry:
|
class MultiModalRegistry:
|
||||||
"""
|
"""
|
||||||
A registry to dispatch data processing
|
A registry that dispatches data processing to the
|
||||||
according to its modality and the target model.
|
:class:`~vllm.multimodal.MultiModalPlugin` for each modality.
|
||||||
|
|
||||||
The registry handles both external and internal data input.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_PLUGINS = (ImagePlugin(), )
|
DEFAULT_PLUGINS = (ImagePlugin(), )
|
||||||
@ -30,6 +28,12 @@ class MultiModalRegistry:
|
|||||||
self._plugins = {p.get_data_key(): p for p in plugins}
|
self._plugins = {p.get_data_key(): p for p in plugins}
|
||||||
|
|
||||||
def register_plugin(self, plugin: MultiModalPlugin) -> None:
|
def register_plugin(self, plugin: MultiModalPlugin) -> None:
|
||||||
|
"""
|
||||||
|
Register a multi-modal plugin so it can be recognized by vLLM.
|
||||||
|
|
||||||
|
See also:
|
||||||
|
:ref:`adding_multimodal_plugin`
|
||||||
|
"""
|
||||||
data_type_key = plugin.get_data_key()
|
data_type_key = plugin.get_data_key()
|
||||||
|
|
||||||
if data_type_key in self._plugins:
|
if data_type_key in self._plugins:
|
||||||
@ -75,7 +79,11 @@ class MultiModalRegistry:
|
|||||||
data: MultiModalDataDict) -> MultiModalInputs:
|
data: MultiModalDataDict) -> MultiModalInputs:
|
||||||
"""
|
"""
|
||||||
Apply an input mapper to the data passed to the model.
|
Apply an input mapper to the data passed to the model.
|
||||||
|
|
||||||
|
The data belonging to each modality is passed to the corresponding
|
||||||
|
plugin which in turn converts the data into keyword arguments
|
||||||
|
via the input mapper registered for that model.
|
||||||
|
|
||||||
See :meth:`MultiModalPlugin.map_input` for more details.
|
See :meth:`MultiModalPlugin.map_input` for more details.
|
||||||
"""
|
"""
|
||||||
merged_dict: Dict[str, torch.Tensor] = {}
|
merged_dict: Dict[str, torch.Tensor] = {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user