From 8a924d2248dedb620eb9a32ca5c9f97ab525aaf5 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Wed, 10 Jul 2024 14:55:34 +0800
Subject: [PATCH] [Doc] Guide for adding multi-modal plugins (#6205)

---
 docs/source/_templates/sections/header.html |  1 +
 .../multimodal/adding_multimodal_plugin.rst |  17 +++++++++++++
 .../dev/multimodal/multimodal_index.rst     |  24 ++++++++++++-------
 vllm/multimodal/__init__.py                 |   5 ++--
 vllm/multimodal/base.py                     |  21 +++++++++-------
 vllm/multimodal/image.py                    |   1 +
 vllm/multimodal/registry.py                 |  18 ++++++++++----
 7 files changed, 64 insertions(+), 23 deletions(-)
 create mode 100644 docs/source/dev/multimodal/adding_multimodal_plugin.rst

diff --git a/docs/source/_templates/sections/header.html b/docs/source/_templates/sections/header.html
index cd5c4053..7174431b 100644
--- a/docs/source/_templates/sections/header.html
+++ b/docs/source/_templates/sections/header.html
@@ -5,6 +5,7 @@
     justify-content: center;
     align-items: center;
     font-size: 16px;
+    padding: 0 6px 0 6px;
   }
   .notification-bar p {
     margin: 0;
diff --git a/docs/source/dev/multimodal/adding_multimodal_plugin.rst b/docs/source/dev/multimodal/adding_multimodal_plugin.rst
new file mode 100644
index 00000000..b726138f
--- /dev/null
+++ b/docs/source/dev/multimodal/adding_multimodal_plugin.rst
@@ -0,0 +1,17 @@
+.. _adding_multimodal_plugin:
+
+Adding a Multimodal Plugin
+==========================
+
+This document teaches you how to add a new modality to vLLM.
+
+Each modality in vLLM is represented by a :class:`~vllm.multimodal.MultiModalPlugin` and registered to :data:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
+For vLLM to recognize a new modality type, you have to create a new plugin and then pass it to :meth:`~vllm.multimodal.MultiModalRegistry.register_plugin`.
+
+The remainder of this document details how to define custom :class:`~vllm.multimodal.MultiModalPlugin` s.
+
+.. note::
+    This article is a work in progress.
+
+..
+    TODO: Add more instructions on how to add new plugins once embeddings is in.
diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst
index 39daf30a..6713dcf0 100644
--- a/docs/source/dev/multimodal/multimodal_index.rst
+++ b/docs/source/dev/multimodal/multimodal_index.rst
@@ -7,17 +7,21 @@ Multi-Modality
 
 vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
 
-Multi-modal input can be passed alongside text and token prompts to :ref:`supported models `
+Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models `
 via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
 
-.. note::
-   ``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through
-   the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
+Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
+by following :ref:`this guide <adding_multimodal_plugin>`.
 
-To implement a new multi-modal model in vLLM, please follow :ref:`this guide `.
+Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here `.
 
-..
-   TODO: Add more instructions on how to add new plugins once embeddings is in.
+Guides
+++++++
+
+.. toctree::
+   :maxdepth: 1
+
+   adding_multimodal_plugin
 
 Module Contents
 +++++++++++++++
@@ -36,10 +40,14 @@ Registry
 Base Classes
 ------------
 
-.. autoclass:: vllm.multimodal.MultiModalDataDict
+.. autodata:: vllm.multimodal.BatchedTensors
+
+.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
     :members:
     :show-inheritance:
 
+.. autodata:: vllm.multimodal.MultiModalDataDict
+
 .. autoclass:: vllm.multimodal.MultiModalInputs
     :members:
     :show-inheritance:
diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py
index b6d93065..503dceab 100644
--- a/vllm/multimodal/__init__.py
+++ b/vllm/multimodal/__init__.py
@@ -1,5 +1,5 @@
-from .base import (BatchedTensors, MultiModalDataDict, MultiModalInputs,
-                   MultiModalPlugin)
+from .base import (BatchedTensors, MultiModalDataBuiltins, MultiModalDataDict,
+                   MultiModalInputs, MultiModalPlugin)
 from .registry import MultiModalRegistry
 
 MULTIMODAL_REGISTRY = MultiModalRegistry()
@@ -13,6 +13,7 @@ See also:
 
 __all__ = [
     "BatchedTensors",
+    "MultiModalDataBuiltins",
    "MultiModalDataDict",
     "MultiModalInputs",
     "MultiModalPlugin",
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index 0e31816a..3ebc25c5 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -43,9 +43,6 @@ class MultiModalInputs(_MultiModalInputsBase):
         *,
         device: torch.types.Device,
     ) -> BatchedTensors:
-        # Avoid initializing CUDA too early
-        import torch
-
         unbatched_shape = tensors[0].shape[1:]
 
         for tensor in tensors:
@@ -84,16 +81,21 @@ class MultiModalInputs(_MultiModalInputsBase):
 
 
 class MultiModalDataBuiltins(TypedDict, total=False):
+    """Modality types that are predefined by vLLM."""
+
     image: Image.Image
+    """The input image."""
 
 
 MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
 """
 A dictionary containing an item for each modality type to input.
 
-The data belonging to each modality is converted into keyword arguments
-to the model by the corresponding mapper. By default, the mapper of
-the corresponding plugin with the same modality key is applied.
+Note:
+    This dictionary also accepts modality keys defined outside
+    :class:`MultiModalDataBuiltins` as long as a customized plugin is registered
+    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
+    Read more on that :ref:`here <adding_multimodal_plugin>`.
 """
 
 MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
@@ -123,6 +125,9 @@ class MultiModalPlugin(ABC):
     process the same data differently). This registry is in turn used by
     :class:`~MultiModalRegistry` which acts at a higher level
     (i.e., the modality of the data).
+
+    See also:
+        :ref:`adding_multimodal_plugin`
     """
 
     def __init__(self) -> None:
@@ -183,8 +188,8 @@ class MultiModalPlugin(ABC):
     def map_input(self, model_config: ModelConfig,
                   data: object) -> MultiModalInputs:
         """
-        Apply an input mapper to a data passed
-        to the model, transforming the data into a dictionary of model inputs.
+        Transform the data into a dictionary of model inputs using the
+        input mapper registered for that model.
 
         The model is identified by ``model_config``.
 
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index b6c73512..3b37ce91 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -100,6 +100,7 @@ def repeat_and_pad_image_tokens(
 
 
 class ImagePlugin(MultiModalPlugin):
+    """Plugin for image data."""
 
     def get_data_key(self) -> str:
         return "image"
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index e0716bbf..d8e1b681 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -15,10 +15,8 @@ logger = init_logger(__name__)
 
 class MultiModalRegistry:
     """
-    A registry to dispatch data processing
-    according to its modality and the target model.
-
-    The registry handles both external and internal data input.
+    A registry that dispatches data processing to the
+    :class:`~vllm.multimodal.MultiModalPlugin` for each modality.
     """
 
     DEFAULT_PLUGINS = (ImagePlugin(), )
@@ -30,6 +28,12 @@ class MultiModalRegistry:
         self._plugins = {p.get_data_key(): p for p in plugins}
 
     def register_plugin(self, plugin: MultiModalPlugin) -> None:
+        """
+        Register a multi-modal plugin so it can be recognized by vLLM.
+
+        See also:
+            :ref:`adding_multimodal_plugin`
+        """
         data_type_key = plugin.get_data_key()
 
         if data_type_key in self._plugins:
@@ -75,7 +79,11 @@ class MultiModalRegistry:
                   data: MultiModalDataDict) -> MultiModalInputs:
         """
         Apply an input mapper to the data passed to the model.
-
+
+        The data belonging to each modality is passed to the corresponding
+        plugin, which in turn converts the data into keyword arguments
+        via the input mapper registered for that model.
+
         See :meth:`MultiModalPlugin.map_input` for more details.
         """
         merged_dict: Dict[str, torch.Tensor] = {}
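
As a rough illustration of the workflow in the new ``adding_multimodal_plugin.rst`` guide, the sketch below defines a hypothetical ``audio`` modality and registers it with the global registry. Only ``MULTIMODAL_REGISTRY``, ``register_plugin()`` and ``get_data_key()`` appear in this patch; the ``_default_input_mapper()`` override and the ``audio_features`` keyword argument are assumptions made for the example.

.. code-block:: python

    import torch

    from vllm.inputs.registry import InputContext
    from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
    from vllm.multimodal.base import MultiModalPlugin


    class AudioPlugin(MultiModalPlugin):
        """Plugin for (hypothetical) audio data."""

        def get_data_key(self) -> str:
            # Key under which users pass this modality in ``multi_modal_data``.
            return "audio"

        def _default_input_mapper(self, ctx: InputContext,
                                  data: object) -> MultiModalInputs:
            # Convert the raw modality data into model keyword arguments
            # (the name "audio_features" is made up for this sketch).
            if isinstance(data, torch.Tensor):
                return MultiModalInputs({"audio_features": data})
            raise TypeError(f"Unsupported audio data type: {type(data)}")


    # Make vLLM aware of the new modality, as described in the guide.
    MULTIMODAL_REGISTRY.register_plugin(AudioPlugin())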
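
For the ``multi_modal_data`` field documented in ``multimodal_index.rst``, a minimal usage sketch follows; the model name, prompt format, and image path are placeholders for any supported vision-language model and input, not part of this patch.

.. code-block:: python

    from PIL import Image

    from vllm import LLM

    # Placeholder checkpoint; substitute any supported vision-language model.
    llm = LLM(model="llava-hf/llava-1.5-7b-hf")

    # "image" is the built-in key from MultiModalDataBuiltins; a custom plugin
    # would use whatever its get_data_key() returns.
    outputs = llm.generate({
        "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
        "multi_modal_data": {"image": Image.open("example.jpg")},
    })
    print(outputs[0].outputs[0].text)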