[Doc] Move multimodal Embedding API example to Online Serving page (#14017)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-02-28 15:12:04 +08:00 committed by GitHub
parent 73e0225ee9
commit 1088f06242
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 89 additions and 84 deletions

View File

@ -16,7 +16,7 @@ To input multi-modal data, follow this schema in {class}`vllm.inputs.PromptType`
- `prompt`: The prompt should follow the format that is documented on HuggingFace.
- `multi_modal_data`: This is a dictionary that follows the schema defined in {class}`vllm.multimodal.inputs.MultiModalDataDict`.
### Image
### Image Inputs
You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples:
@ -120,20 +120,20 @@ for o in outputs:
print(generated_text)
```
### Video
### Video Inputs
You can pass a list of NumPy arrays directly to the `'video'` field of the multi-modal dictionary
instead of using multi-image input.
Full example: <gh-file:examples/offline_inference/vision_language.py>
### Audio
### Audio Inputs
You can pass a tuple `(array, sampling_rate)` to the `'audio'` field of the multi-modal dictionary.
Full example: <gh-file:examples/offline_inference/audio_language.py>
### Embedding
### Embedding Inputs
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
@ -211,7 +211,7 @@ The chat template can be inferred based on the documentation on the model's Hugg
For example, LLaVA-1.5 (`llava-hf/llava-1.5-7b-hf`) requires a chat template that can be found here: <gh-file:examples/template_llava.jinja>
:::
### Image
### Image Inputs
Image input is supported according to [OpenAI Vision API](https://platform.openai.com/docs/guides/vision).
Here is a simple example using Phi-3.5-Vision.
@ -293,7 +293,7 @@ export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>
:::
### Video
### Video Inputs
Instead of `image_url`, you can pass a video file via `video_url`. Here is a simple example using [LLaVA-OneVision](https://huggingface.co/llava-hf/llava-onevision-qwen2-0.5b-ov-hf).
@ -356,7 +356,7 @@ export VLLM_VIDEO_FETCH_TIMEOUT=<timeout>
:::
### Audio
### Audio Inputs
Audio input is supported according to [OpenAI Audio API](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in).
Here is a simple example using Ultravox-v0.5-1B.
@ -460,77 +460,6 @@ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>
:::
### Embedding
### Embedding Inputs
vLLM's Embeddings API is a superset of OpenAI's [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings),
where a list of chat `messages` can be passed instead of batched `inputs`. This enables multi-modal inputs to be passed to embedding models.
:::{tip}
The schema of `messages` is exactly the same as in Chat Completions API.
You can refer to the above tutorials for more details on how to pass each type of multi-modal data.
:::
Usually, embedding models do not expect chat-based input, so we need to use a custom chat template to format the text and images.
Refer to the examples below for illustration.
Here is an end-to-end example using VLM2Vec. To serve the model:
```bash
vllm serve TIGER-Lab/VLM2Vec-Full --task embed \
--trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja
```
:::{important}
Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass `--task embed`
to run this model in embedding mode instead of text generation mode.
The custom chat template is completely different from the original one for this model,
and can be found here: <gh-file:examples/template_vlm2vec.jinja>
:::
Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level `requests` library:
```python
import requests
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
response = requests.post(
"http://localhost:8000/v1/embeddings",
json={
"model": "TIGER-Lab/VLM2Vec-Full",
"messages": [{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "Represent the given image."},
],
}],
"encoding_format": "float",
},
)
response.raise_for_status()
response_json = response.json()
print("Embedding output:", response_json["data"][0]["embedding"])
```
Below is another example, this time using the `MrLight/dse-qwen2-2b-mrl-v1` model.
```bash
vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embed \
--trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja
```
:::{important}
Like with VLM2Vec, we have to explicitly pass `--task embed`.
Additionally, `MrLight/dse-qwen2-2b-mrl-v1` requires an EOS token for embeddings, which is handled
by a custom chat template: <gh-file:examples/template_dse_qwen2_vl.jinja>
:::
:::{important}
Also important, `MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code
example below for details.
:::
Full example: <gh-file:examples/online_serving/openai_chat_embedding_client_for_multimodal.py>
TBD

View File

@ -266,11 +266,85 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
If the model has a [chat template](#chat-template), you can replace `inputs` with a list of `messages` (same schema as [Chat API](#chat-api))
which will be treated as a single prompt to the model.
:::{tip}
This enables multi-modal inputs to be passed to embedding models, see [this page](#multimodal-inputs) for details.
Code example: <gh-file:examples/online_serving/openai_embedding_client.py>
#### Multi-modal inputs
You can pass multi-modal inputs to embedding models by defining a custom chat template for the server
and passing a list of `messages` in the request. Refer to the examples below for illustration.
:::::{tab-set}
::::{tab-item} VLM2Vec
To serve the model:
```bash
vllm serve TIGER-Lab/VLM2Vec-Full --task embed \
--trust-remote-code --max-model-len 4096 --chat-template examples/template_vlm2vec.jinja
```
:::{important}
Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass `--task embed`
to run this model in embedding mode instead of text generation mode.
The custom chat template is completely different from the original one for this model,
and can be found here: <gh-file:examples/template_vlm2vec.jinja>
:::
Code example: <gh-file:examples/online_serving/openai_embedding_client.py>
Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level `requests` library:
```python
import requests
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
response = requests.post(
"http://localhost:8000/v1/embeddings",
json={
"model": "TIGER-Lab/VLM2Vec-Full",
"messages": [{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": image_url}},
{"type": "text", "text": "Represent the given image."},
],
}],
"encoding_format": "float",
},
)
response.raise_for_status()
response_json = response.json()
print("Embedding output:", response_json["data"][0]["embedding"])
```
::::
::::{tab-item} DSE-Qwen2-MRL
To serve the model:
```bash
vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embed \
--trust-remote-code --max-model-len 8192 --chat-template examples/template_dse_qwen2_vl.jinja
```
:::{important}
Like with VLM2Vec, we have to explicitly pass `--task embed`.
Additionally, `MrLight/dse-qwen2-2b-mrl-v1` requires an EOS token for embeddings, which is handled
by a custom chat template: <gh-file:examples/template_dse_qwen2_vl.jinja>
:::
:::{important}
`MrLight/dse-qwen2-2b-mrl-v1` requires a placeholder image of the minimum image size for text query embeddings. See the full code
example below for details.
:::
::::
:::::
Full example: <gh-file:examples/online_serving/openai_chat_embedding_client_for_multimodal.py>
#### Extra parameters

View File

@ -19,6 +19,7 @@ import cloudpickle
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_in_doc_build
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
supports_cross_encoding, supports_multimodal,
@ -368,7 +369,8 @@ class _ModelRegistry:
raise ValueError(msg)
model = _LazyRegisteredModel(*split_str)
elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
elif isinstance(model_cls, type) and (is_in_doc_build() or issubclass(
model_cls, nn.Module)):
model = _RegisteredModel.from_model_cls(model_cls)
else:
msg = ("`model_cls` should be a string or PyTorch model class, "