[Doc][V1] Update model implementation guide for V1 support (#11998)
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -57,7 +57,17 @@ class MyModelForCausalLM(nn.Module):
### Computation Code
Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.
- Add a `get_input_embeddings` method inside the `MyModel` module that returns the text embeddings given `input_ids`. This is equivalent to directly calling the text embedding layer, but it provides a unified interface in case `MyModel` is used within a composite multimodal model.
```python
class MyModel(nn.Module):
    ...

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        ...
```
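The payoff shows up later in this guide: a composite multimodal wrapper can call this method instead of reaching into the embedding layer directly. A hypothetical sketch of that usage (`MyMultiModalModel` and its `language_model` attribute are illustrative, not part of the vLLM API):

```python
class MyMultiModalModel(nn.Module):

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Delegate to the wrapped language model's unified interface,
        # rather than calling its embedding layer directly.
        return self.language_model.get_input_embeddings(input_ids)
```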
- Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.
```python
def forward(
@@ -9,7 +9,78 @@ This document walks you through the steps to extend a basic model so that it acc
It is assumed that you have already implemented the model in vLLM according to [these steps](#new-model-basic).
Further update the model as follows:
- Implement the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
- Reserve a keyword parameter in {meth}`~torch.nn.Module.forward` for each input tensor that corresponds to a multi-modal input, as shown in the following example:
```diff
  def forward(
      self,
      input_ids: torch.Tensor,
      positions: torch.Tensor,
      kv_caches: List[torch.Tensor],
      attn_metadata: AttentionMetadata,
+     pixel_values: torch.Tensor,
  ) -> SamplerOutput:
```
More conveniently, you can simply pass `**kwargs` to the {meth}`~torch.nn.Module.forward` method and retrieve the keyword parameters for multimodal inputs from it.
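For illustration, a minimal sketch of that pattern; the `pixel_values` key simply mirrors the example above and stands in for whatever multimodal inputs your model actually accepts:

```python
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    **kwargs: object,
) -> SamplerOutput:
    # Multimodal tensors arrive as keyword arguments; pull out the ones
    # this model cares about (here, a hypothetical `pixel_values` input).
    pixel_values = kwargs.pop("pixel_values", None)
    ...
```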
- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings` that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
```python
class YourModelForImage2Seq(nn.Module):
    ...

    def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:

        assert self.vision_encoder is not None
        image_features = self.vision_encoder(image_input)
        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs: object) -> Optional[NestedTensors]:

        # Validate the multimodal input keyword arguments
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        # Run multimodal inputs through encoder and projector
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings
```
```{important}
The returned `multimodal_embeddings` must be either a **3D {class}`torch.Tensor`** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D {class}`torch.Tensor`'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., image) of the request.
```
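As a concrete illustration of the two accepted forms (all sizes below are arbitrary, made-up numbers):

```python
import torch

num_items, hidden_size = 2, 4096

# Form 1: a single 3D tensor, one (feature_size, hidden_size) slice per item
batched = torch.zeros(num_items, 576, hidden_size)

# Form 2: a list/tuple of 2D tensors, which also allows feature_size to vary per item
per_item = [torch.zeros(576, hidden_size), torch.zeros(144, hidden_size)]

# Either way, indexing by item works the same:
assert batched[0].shape == (576, hidden_size)
assert per_item[1].shape == (144, hidden_size)
```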
- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings` to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
```python
from .utils import merge_multimodal_embeddings

class YourModelForImage2Seq(nn.Module):
    ...

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:

        # `get_input_embeddings` should already be implemented for the language
        # model as one of the requirements of basic vLLM model implementation.
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)

        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                multimodal_embeddings=multimodal_embeddings,
                placeholder_token_id=self.config.image_token_index)

        return inputs_embeds
```
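Conceptually, the merge replaces the rows of `inputs_embeds` whose `input_ids` equal the placeholder token with the multimodal embeddings, in order of appearance. The toy function below is not the vLLM utility, just a self-contained sketch of that idea (all shapes and token IDs are made up):

```python
import torch

def toy_merge(input_ids: torch.Tensor,
              inputs_embeds: torch.Tensor,
              multimodal_embeddings: torch.Tensor,
              placeholder_token_id: int) -> torch.Tensor:
    # Overwrite the embeddings at placeholder positions with the
    # flattened multimodal embeddings.
    mask = input_ids == placeholder_token_id
    out = inputs_embeds.clone()
    out[mask] = multimodal_embeddings.reshape(-1, inputs_embeds.shape[-1])
    return out

# Toy example: 6 tokens, positions 1 and 2 hold a made-up image placeholder ID.
hidden_size = 8
input_ids = torch.tensor([1, 32000, 32000, 5, 6, 7])
inputs_embeds = torch.zeros(6, hidden_size)
image_embeds = torch.ones(1, 2, hidden_size)  # (num_items, feature_size, hidden_size)
merged = toy_merge(input_ids, inputs_embeds, image_embeds, placeholder_token_id=32000)
assert merged[1].sum() == hidden_size
assert merged[3].sum() == 0
```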
- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
```diff
+ from vllm.model_executor.models.interfaces import SupportsMultiModal
@@ -23,20 +94,6 @@ Further update the model as follows:
Check out [the HuggingFace Transformers documentation](https://huggingface.co/docs/transformers/model_doc/auto#multimodal) for some examples.
```
- If you haven't already done so, reserve a keyword parameter in {meth}`~torch.nn.Module.forward`
  for each input tensor that corresponds to a multi-modal input, as shown in the following example:
```diff
  def forward(
      self,
      input_ids: torch.Tensor,
      positions: torch.Tensor,
      kv_caches: List[torch.Tensor],
      attn_metadata: AttentionMetadata,
+     pixel_values: torch.Tensor,
  ) -> SamplerOutput:
```
## 2. Specify processing information
Next, create a subclass of {class}`~vllm.multimodal.processing.BaseProcessingInfo`