
(enabling-multimodal-inputs)=

# Enabling Multimodal Inputs

This document walks you through the steps to extend a vLLM model so that it accepts multi-modal inputs.

```{seealso}
[Adding a New Model](adding-a-new-model)
```

## 1. Update the base vLLM model

It is assumed that you have already implemented the model in vLLM according to [these steps](adding-a-new-model). Further update the model as follows:

- Implement the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.

  ```diff
  + from vllm.model_executor.models.interfaces import SupportsMultiModal

  - class YourModelForImage2Seq(nn.Module):
  + class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
  ```

  The model class does not have to be named {code}`*ForCausalLM`.
  Check out [the HuggingFace Transformers documentation](https://huggingface.co/docs/transformers/model_doc/auto#multimodal) for some examples.
    
- If you haven't already done so, reserve a keyword parameter in {meth}`~torch.nn.Module.forward` for each input tensor that corresponds to a multi-modal input, as shown in the following example:

  ```diff
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
  +     pixel_values: torch.Tensor,
    ) -> SamplerOutput:
  ```

## 2. Register input mappers

For each modality type that the model accepts as input, decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`. This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in {meth}`~torch.nn.Module.forward`.

```diff
  from vllm.model_executor.models.interfaces import SupportsMultiModal
+ from vllm.multimodal import MULTIMODAL_REGISTRY

+ @MULTIMODAL_REGISTRY.register_image_input_mapper()
  class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```

A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
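
If you do provide your own function, it receives the raw multi-modal data together with an `InputContext` describing the model, and must produce the keyword arguments reserved in {meth}`~torch.nn.Module.forward`. Here is a minimal, hypothetical sketch (not taken from any in-tree model); the use of `AutoImageProcessor` and the plain-dict return value are assumptions and should be checked against the return type your vLLM version expects:

```python
from PIL import Image
from transformers import AutoImageProcessor

from vllm.inputs import InputContext


# Hypothetical custom mapper (illustration only): converts a raw PIL image into
# the `pixel_values` keyword argument reserved in forward(). A real
# implementation would cache the HF image processor rather than reload it on
# every call, and the plain-dict return value is an assumption; check the
# return type that register_image_input_mapper expects in your vLLM version.
def my_image_input_mapper(ctx: InputContext, data: Image.Image):
    image_processor = AutoImageProcessor.from_pretrained(ctx.model_config.model)
    out = image_processor(images=data, return_tensors="pt")
    # The keys must match the keyword parameters defined in forward().
    return {"pixel_values": out["pixel_values"]}
```

Such a function would then be passed to the decorator shown above, e.g. `@MULTIMODAL_REGISTRY.register_image_input_mapper(my_image_input_mapper)`, in place of the empty parentheses.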

```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```

## 3. Register maximum number of multi-modal tokens

For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item and register it via {meth}`MULTIMODAL_REGISTRY.register_max_multimodal_tokens <vllm.multimodal.MultiModalRegistry.register_max_multimodal_tokens>`.

```diff
  from vllm.model_executor.models.interfaces import SupportsMultiModal
  from vllm.multimodal import MULTIMODAL_REGISTRY

  @MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
  class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
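
As a rough sketch of what `<your_calculation>` might look like for a ViT-style image encoder (the `vision_config`, `image_size`, and `patch_size` attribute names and the extra class token are assumptions; they differ per model):

```python
from vllm.inputs import InputContext


# Hypothetical calculation for a ViT-style image encoder. The config attribute
# names and the "+ 1" for a class token are assumptions; use whatever actually
# determines the maximum image feature size for your model.
def get_max_my_model_image_tokens(ctx: InputContext) -> int:
    vision_config = ctx.get_hf_config().vision_config
    num_patches_per_side = vision_config.image_size // vision_config.patch_size
    return num_patches_per_side * num_patches_per_side + 1
```

A callable like this then goes in place of `<your_calculation>` in the decorator above.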

Here are some examples:

```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```

## 4. (Optional) Register dummy data

During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models. In such cases, you can define your own dummy data by registering a factory method via {meth}`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.

```diff
+ from vllm.inputs import INPUT_REGISTRY
  from vllm.model_executor.models.interfaces import SupportsMultiModal
  from vllm.multimodal import MULTIMODAL_REGISTRY

  @MULTIMODAL_REGISTRY.register_image_input_mapper()
  @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
  class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```

The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
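
As a rough, hypothetical sketch of such a factory for an image-only model (it reuses the `get_max_my_model_image_tokens` helper from the previous sketch; the `image_token_index` attribute, the `SequenceData.from_prompt_token_counts` constructor, and the `(sequence data, multi-modal data)` return shape are assumptions, so compare against an in-tree multi-modal model for the exact structure your vLLM version expects):

```python
from typing import Mapping

from PIL import Image

from vllm.inputs import InputContext
from vllm.sequence import SequenceData


# Hypothetical dummy-data factory (illustration only). It builds a prompt of
# length seq_len whose leading tokens are image placeholders, plus one blank
# image per requested item. `image_token_index` is an assumed config attribute.
def dummy_data_for_my_model(ctx: InputContext, seq_len: int,
                            mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config()
    num_images = mm_counts["image"]

    # Worst case: every image expands to its maximum number of feature tokens;
    # the remainder of the sequence is padded with token id 0.
    num_image_tokens = get_max_my_model_image_tokens(ctx) * num_images
    seq_data = SequenceData.from_prompt_token_counts(
        (hf_config.image_token_index, num_image_tokens),
        (0, max(0, seq_len - num_image_tokens)),
    )

    # One blank image at the maximum supported resolution for each item.
    size = hf_config.vision_config.image_size
    image = Image.new("RGB", (size, size), color=0)
    mm_data = {"image": image if num_images == 1 else [image] * num_images}
    return seq_data, mm_data
```

This factory then goes in place of `<your_dummy_data_factory>` in the decorator above.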

Here are some examples:

```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```

## 5. (Optional) Register input processor

Sometimes, there is a need to process inputs at the {class}`~vllm.LLMEngine` level before they are passed to the model executor. This is often because, unlike in HuggingFace Transformers, the reshaping and/or expansion of multi-modal embeddings needs to take place outside the model's {meth}`~torch.nn.Module.forward` call. You can register input processors via {meth}`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.

```diff
  from vllm.inputs import INPUT_REGISTRY
  from vllm.model_executor.models.interfaces import SupportsMultiModal
  from vllm.multimodal import MULTIMODAL_REGISTRY

  @MULTIMODAL_REGISTRY.register_image_input_mapper()
  @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
  @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
  class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
```
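
As a rough, hypothetical sketch of such a processor for an image model (it assumes the `DecoderOnlyInputs` and `token_inputs` helpers from `vllm.inputs`, an `image_token_index` config attribute, and the `get_max_my_model_image_tokens` helper from step 3), expanding each single image placeholder token into the number of feature tokens the vision encoder will produce:

```python
from typing import List

from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs


# Hypothetical input processor (illustration only): expand each image
# placeholder token so that the token sequence length matches the number of
# image embeddings merged into it later, letting vLLM build correct attention
# metadata. `image_token_index` is an assumed config attribute.
def input_processor_for_my_model(ctx: InputContext,
                                 inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        # Text-only prompt: nothing to do.
        return inputs

    hf_config = ctx.get_hf_config()
    image_token_id = hf_config.image_token_index
    num_image_tokens = get_max_my_model_image_tokens(ctx)

    new_token_ids: List[int] = []
    for token_id in inputs["prompt_token_ids"]:
        if token_id == image_token_id:
            new_token_ids.extend([image_token_id] * num_image_tokens)
        else:
            new_token_ids.append(token_id)

    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=inputs.get("prompt"),
                        multi_modal_data=multi_modal_data)
```

The function then goes in place of `<your_input_processor>` in the decorator above.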

A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation. Here are some examples:

```{seealso}
[Input Processing Pipeline](#input-processing-pipeline)
```