[Doc] Create a new "Usage" section (#10827)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 8d370e91cb
commit aa39a8e175
@@ -7,7 +7,7 @@ Multi-Modality

 vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

-Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
+Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_mm_models>`
 via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.

 Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
@@ -15,9 +15,6 @@ by following :ref:`this guide <adding_multimodal_plugin>`.

 Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.

-..
-TODO: Add usage of --limit-mm-per-prompt when multi-image input is officially supported
-
 Guides
 ++++++

@@ -85,12 +85,8 @@ Documentation
 serving/deploying_with_nginx
 serving/distributed_serving
 serving/metrics
-serving/env_vars
-serving/usage_stats
 serving/integrations
 serving/tensorizer
-serving/compatibility_matrix
-serving/faq

 .. toctree::
 :maxdepth: 1
@@ -99,12 +95,21 @@ Documentation
 models/supported_models
 models/adding_model
 models/enabling_multimodal_inputs
-models/engine_args
-models/lora
-models/vlm
-models/structured_outputs
-models/spec_decode
-models/performance
+
+.. toctree::
+:maxdepth: 1
+:caption: Usage
+
+usage/lora
+usage/multimodal_inputs
+usage/structured_outputs
+usage/spec_decode
+usage/compatibility_matrix
+usage/performance
+usage/faq
+usage/engine_args
+usage/env_vars
+usage/usage_stats

 .. toctree::
 :maxdepth: 1
@@ -3,7 +3,7 @@
 Enabling Multimodal Inputs
 ==========================

-This document walks you through the steps to extend a vLLM model so that it accepts :ref:`multi-modal <multi_modality>` inputs.
+This document walks you through the steps to extend a vLLM model so that it accepts :ref:`multi-modal inputs <multimodal_inputs>`.

 .. seealso::
 :ref:`adding_a_new_model`
@@ -471,6 +471,8 @@ Sentence Pair Scoring
 .. note::
 These models are supported in both offline and online inference via Score API.

+.. _supported_mm_models:
+
 Multimodal Language Models
 ^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -489,8 +491,6 @@ On the other hand, modalities separated by :code:`/` are mutually exclusive.

 - e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs.

-.. _supported_vlms:
-
 Text Generation
 ---------------

@@ -646,6 +646,21 @@ Text Generation
 | :sup:`E` Pre-computed embeddings can be inputted for this modality.
 | :sup:`+` Multiple items can be inputted per text prompt for this modality.

+.. important::
+To enable multiple multi-modal items per text prompt, you have to set :code:`limit_mm_per_prompt` (offline inference)
+or :code:`--limit-mm-per-prompt` (online inference). For example, to enable passing up to 4 images per text prompt:
+
+.. code-block:: python
+
+llm = LLM(
+model="Qwen/Qwen2-VL-7B-Instruct",
+limit_mm_per_prompt={"image": 4},
+)
+
+.. code-block:: bash
+
+vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4
+
 .. note::
 vLLM currently only supports adding LoRA to the language backbone of multimodal models.

@@ -32,7 +32,7 @@ We currently support the following OpenAI APIs:
 - [Completions API](https://platform.openai.com/docs/api-reference/completions)
 - *Note: `suffix` parameter is not supported.*
 - [Chat Completions API](https://platform.openai.com/docs/api-reference/chat)
-- [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Using VLMs](../models/vlm.rst).
+- [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Multimodal Inputs](../usage/multimodal_inputs.rst).
 - *Note: `image_url.detail` parameter is not supported.*
 - We also support `audio_url` content type for audio files.
 - Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema.
@@ -41,7 +41,7 @@ We currently support the following OpenAI APIs:
 - [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings)
 - Instead of `inputs`, you can pass in a list of `messages` (same schema as Chat Completions API),
 which will be treated as a single prompt to the model according to its chat template.
-- This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst).
+- This enables multi-modal inputs to be passed to embedding models, see [this page](../usage/multimodal_inputs.rst) for details.
 - *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.*

 ## Score API for Cross Encoder Models
@@ -1,3 +1,5 @@
+.. _faq:
+
 Frequently Asked Questions
 ===========================

@@ -1,7 +1,7 @@
 .. _lora:

-Using LoRA adapters
-===================
+LoRA Adapters
+=============

 This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.

@@ -1,34 +1,31 @@
-.. _vlm:
+.. _multimodal_inputs:

-Using VLMs
-==========
+Multimodal Inputs
+=================

-vLLM provides experimental support for Vision Language Models (VLMs). See the :ref:`list of supported VLMs here <supported_vlms>`.
-This document shows you how to run and serve these models using vLLM.
+This page teaches you how to pass multi-modal inputs to :ref:`multi-modal models <supported_mm_models>` in vLLM.

 .. note::
-We are actively iterating on VLM support. See `this RFC <https://github.com/vllm-project/vllm/issues/4194>`_ for upcoming changes,
+We are actively iterating on multi-modal support. See `this RFC <https://github.com/vllm-project/vllm/issues/4194>`_ for upcoming changes,
 and `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.

 Offline Inference
 -----------------

-Single-image input
-^^^^^^^^^^^^^^^^^^
-
-The :class:`~vllm.LLM` class can be instantiated in much the same way as language-only models.
-
-.. code-block:: python
-
-llm = LLM(model="llava-hf/llava-1.5-7b-hf")
-
-To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:
+To input multi-modal data, follow this schema in :class:`vllm.inputs.PromptType`:

 * ``prompt``: The prompt should follow the format that is documented on HuggingFace.
 * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.

+Image
+^^^^^
+
+You can pass a single image to the :code:`'image'` field of the multi-modal dictionary, as shown in the following examples:
+
 .. code-block:: python

+llm = LLM(model="llava-hf/llava-1.5-7b-hf")
+
 # Refer to the HuggingFace repo for the correct format to use
 prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

|
|||||||
"multi_modal_data": {"image": image},
|
"multi_modal_data": {"image": image},
|
||||||
})
|
})
|
||||||
|
|
||||||
for o in outputs:
|
|
||||||
generated_text = o.outputs[0].text
|
|
||||||
print(generated_text)
|
|
||||||
|
|
||||||
# Inference with image embeddings as input
|
|
||||||
image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
|
|
||||||
outputs = llm.generate({
|
|
||||||
"prompt": prompt,
|
|
||||||
"multi_modal_data": {"image": image_embeds},
|
|
||||||
})
|
|
||||||
|
|
||||||
for o in outputs:
|
|
||||||
generated_text = o.outputs[0].text
|
|
||||||
print(generated_text)
|
|
||||||
|
|
||||||
# Inference with image embeddings as input with additional parameters
|
|
||||||
# Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which utilizes additional parameters.
|
|
||||||
mm_data = {}
|
|
||||||
|
|
||||||
image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
|
|
||||||
# For Qwen2VL, image_grid_thw is needed to calculate positional encoding.
|
|
||||||
mm_data['image'] = {
|
|
||||||
"image_embeds": image_embeds,
|
|
||||||
"image_grid_thw": torch.load(...) # torch.Tensor of shape (1, 3),
|
|
||||||
}
|
|
||||||
# For MiniCPM-V, image_size_list is needed to calculate details of the sliced image.
|
|
||||||
mm_data['image'] = {
|
|
||||||
"image_embeds": image_embeds,
|
|
||||||
"image_size_list": [image.size] # list of image sizes
|
|
||||||
}
|
|
||||||
outputs = llm.generate({
|
|
||||||
"prompt": prompt,
|
|
||||||
"multi_modal_data": mm_data,
|
|
||||||
})
|
|
||||||
|
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
generated_text = o.outputs[0].text
|
generated_text = o.outputs[0].text
|
||||||
print(generated_text)
|
print(generated_text)
|
||||||
@@ -102,12 +64,7 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT

 A code example can be found in `examples/offline_inference_vision_language.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_vision_language.py>`_.

-Multi-image input
-^^^^^^^^^^^^^^^^^
-
-Multi-image input is only supported for a subset of VLMs, as shown :ref:`here <supported_vlms>`.
-
-To enable multiple multi-modal items per text prompt, you have to set ``limit_mm_per_prompt`` for the :class:`~vllm.LLM` class.
+To substitute multiple images inside the same text prompt, you can pass in a list of images instead:

 .. code-block:: python

|
|||||||
limit_mm_per_prompt={"image": 2}, # The maximum number to accept
|
limit_mm_per_prompt={"image": 2}, # The maximum number to accept
|
||||||
)
|
)
|
||||||
|
|
||||||
Instead of passing in a single image, you can pass in a list of images.
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
# Refer to the HuggingFace repo for the correct format to use
|
# Refer to the HuggingFace repo for the correct format to use
|
||||||
prompt = "<|user|>\n<|image_1|>\n<|image_2|>\nWhat is the content of each image?<|end|>\n<|assistant|>\n"
|
prompt = "<|user|>\n<|image_1|>\n<|image_2|>\nWhat is the content of each image?<|end|>\n<|assistant|>\n"
|
||||||
|
|
||||||
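The hunk above keeps the two-image prompt and the ``limit_mm_per_prompt={"image": 2}`` setting as context, while the surrounding ``generate`` call lies outside the diff. A hedged sketch of how the pieces fit together (the Phi-3.5-Vision model name and the local file names are illustrative assumptions):

.. code-block:: python

    from vllm import LLM
    from PIL import Image

    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",  # illustrative multi-image-capable model
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": 2},  # The maximum number to accept
    )

    # Refer to the HuggingFace repo for the correct format to use
    prompt = "<|user|>\n<|image_1|>\n<|image_2|>\nWhat is the content of each image?<|end|>\n<|assistant|>\n"

    # Placeholder local files; a list of images is passed instead of a single item
    image_1 = Image.open("image_1.jpg")
    image_2 = Image.open("image_2.jpg")

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": {"image": [image_1, image_2]},
    })

    for o in outputs:
        print(o.outputs[0].text)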
@@ -169,30 +122,114 @@ Multi-image input can be extended to perform video captioning. We show this with
 generated_text = o.outputs[0].text
 print(generated_text)

+Video
+^^^^^
+
+You can pass a list of NumPy arrays directly to the :code:`'video'` field of the multi-modal dictionary
+instead of using multi-image input.
+
+Please refer to `examples/offline_inference_vision_language.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_vision_language.py>`_ for more details.
+
+Audio
+^^^^^
+
+You can pass a tuple :code:`(array, sampling_rate)` to the :code:`'audio'` field of the multi-modal dictionary.
+
+Please refer to `examples/offline_inference_audio_language.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_audio_language.py>`_ for more details.
+
+Embedding
+^^^^^^^^^
+
+To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
+pass a tensor of shape :code:`(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
+
+.. code-block:: python
+
+# Inference with image embeddings as input
+llm = LLM(model="llava-hf/llava-1.5-7b-hf")
+
+# Refer to the HuggingFace repo for the correct format to use
+prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
+
+# Embeddings for single image
+# torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
+image_embeds = torch.load(...)
+
+outputs = llm.generate({
+"prompt": prompt,
+"multi_modal_data": {"image": image_embeds},
+})
+
+for o in outputs:
+generated_text = o.outputs[0].text
+print(generated_text)
+
+For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embeddings:
+
+.. code-block:: python
+
+# Construct the prompt based on your model
+prompt = ...
+
+# Embeddings for multiple images
+# torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
+image_embeds = torch.load(...)
+
+# Qwen2-VL
+llm = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": 4})
+mm_data = {
+"image": {
+"image_embeds": image_embeds,
+# image_grid_thw is needed to calculate positional encoding.
+"image_grid_thw": torch.load(...), # torch.Tensor of shape (1, 3),
+}
+}
+
+# MiniCPM-V
+llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={"image": 4})
+mm_data = {
+"image": {
+"image_embeds": image_embeds,
+# image_size_list is needed to calculate details of the sliced image.
+"image_size_list": [image.size for image in images], # list of image sizes
+}
+}
+
+outputs = llm.generate({
+"prompt": prompt,
+"multi_modal_data": mm_data,
+})
+
+for o in outputs:
+generated_text = o.outputs[0].text
+print(generated_text)
+
 Online Inference
 ----------------

-OpenAI Vision API
-^^^^^^^^^^^^^^^^^
+Our OpenAI-compatible server accepts multi-modal data via the `Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`_.

-You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_.
+.. important::
+A chat template is **required** to use Chat Completions API.

-Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruct`` with vLLM's OpenAI-compatible API server.
+Although most models come with a chat template, for others you have to define one yourself.
+The chat template can be inferred based on the documentation on the model's HuggingFace repo.
+For example, LLaVA-1.5 (``llava-hf/llava-1.5-7b-hf``) requires a chat template that can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`__.
+
+Image
+^^^^^
+
+Image input is supported according to `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_.
+Here is a simple example using Phi-3.5-Vision.
+
+First, launch the OpenAI-compatible server:

 .. code-block:: bash

 vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
 --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2

-.. important::
-Since OpenAI Vision API is based on `Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`_,
-a chat template is **required** to launch the API server.
-
-Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the model's tokenizer does not come with it.
-The chat template can be inferred based on the documentation on the model's HuggingFace repo.
-For example, LLaVA-1.5 (``llava-hf/llava-1.5-7b-hf``) requires a chat template that can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`_.
-
-To consume the server, you can use the OpenAI client like in the example below:
+Then, you can use the OpenAI client as follows:

 .. code-block:: python

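The new Online Inference section ends here with "Then, you can use the OpenAI client as follows:" and a code block that lies beyond the diff context. A minimal sketch of such a client call against the server launched above (the image URL is a placeholder):

.. code-block:: python

    from openai import OpenAI

    # Point the client at the vLLM OpenAI-compatible server started with `vllm serve`
    client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

    chat_response = client.chat.completions.create(
        model="microsoft/Phi-3.5-vision-instruct",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # Placeholder image URL
                {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
            ],
        }],
    )
    print("Chat completion output:", chat_response.choices[0].message.content)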
@@ -252,22 +289,59 @@ A full code example can be found in `examples/openai_chat_completion_client_for_

 .. note::

-By default, the timeout for fetching images through http url is ``5`` seconds. You can override this by setting the environment variable:
+By default, the timeout for fetching images through HTTP URL is ``5`` seconds.
+You can override this by setting the environment variable:

 .. code-block:: console

 $ export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>

-Chat Embeddings API
-^^^^^^^^^^^^^^^^^^^
+Video
+^^^^^

-vLLM's Chat Embeddings API is a superset of OpenAI's `Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`_,
-where a list of ``messages`` can be passed instead of batched ``inputs``. This enables multi-modal inputs to be passed to embedding models.
+Instead of :code:`image_url`, you can pass a video file via :code:`video_url`.
+
+You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/entrypoints/openai/test_video.py>`_ as reference.
+
+.. note::
+
+By default, the timeout for fetching videos through HTTP URL is ``30`` seconds.
+You can override this by setting the environment variable:
+
+.. code-block:: console
+
+$ export VLLM_VIDEO_FETCH_TIMEOUT=<timeout>
+
+Audio
+^^^^^
+
+Instead of :code:`image_url`, you can pass an audio file via :code:`audio_url`.
+
+A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_completion_client_for_multimodal.py>`_.
+
+.. note::
+
+By default, the timeout for fetching audios through HTTP URL is ``10`` seconds.
+You can override this by setting the environment variable:
+
+.. code-block:: console
+
+$ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>
+
+Embedding
+^^^^^^^^^
+
+vLLM's Embeddings API is a superset of OpenAI's `Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`_,
+where a list of chat ``messages`` can be passed instead of batched ``inputs``. This enables multi-modal inputs to be passed to embedding models.

 .. tip::
 The schema of ``messages`` is exactly the same as in Chat Completions API.
+You can refer to the above tutorials for more details on how to pass each type of multi-modal data.

-In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
+Usually, embedding models do not expect chat-based input, so we need to use a custom chat template to format the text and images.
+Refer to the examples below for illustration.
+
+Here is an end-to-end example using VLM2Vec. To serve the model:

 .. code-block:: bash

|
|||||||
Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding``
|
Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding``
|
||||||
to run this model in embedding mode instead of text generation mode.
|
to run this model in embedding mode instead of text generation mode.
|
||||||
|
|
||||||
.. important::
|
The custom chat template is completely different from the original one for this model,
|
||||||
|
and can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/template_vlm2vec.jinja>`__.
|
||||||
VLM2Vec does not expect chat-based input. We use a `custom chat template <https://github.com/vllm-project/vllm/blob/main/examples/template_vlm2vec.jinja>`_
|
|
||||||
to combine the text and images together.
|
|
||||||
|
|
||||||
Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library:
|
Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library:
|
||||||
|
|
||||||
@ -310,7 +382,7 @@ Since the request schema is not defined by OpenAI client, we post a request to t
|
|||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
print("Embedding output:", response_json["data"][0]["embedding"])
|
print("Embedding output:", response_json["data"][0]["embedding"])
|
||||||
|
|
||||||
Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model.
|
Below is another example, this time using the ``MrLight/dse-qwen2-2b-mrl-v1`` model.
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
@@ -319,8 +391,10 @@ Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model.

 .. important::

-Like with VLM2Vec, we have to explicitly pass ``--task embedding``. Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings,
-which is handled by the jinja template.
+Like with VLM2Vec, we have to explicitly pass ``--task embedding``.
+
+Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings, which is handled
+by `this custom chat template <https://github.com/vllm-project/vllm/blob/main/examples/template_dse_qwen2_vl.jinja>`__.

 .. important::

@@ -1,7 +1,7 @@
 .. _spec_decode:

-Speculative decoding in vLLM
-============================
+Speculative decoding
+====================

 .. warning::
 Please note that speculative decoding in vLLM is not yet optimized and does
@@ -182,7 +182,7 @@ speculative decoding, breaking down the guarantees into three key areas:
 3. **vLLM Logprob Stability**
 - vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the
 same request across runs. For more details, see the FAQ section
-titled *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq>`_.
+titled *Can the output of a prompt vary across runs in vLLM?* in the :ref:`FAQs <faq>`.


 **Conclusion**
@@ -197,7 +197,7 @@ can occur due to following factors:

 **Mitigation Strategies**

-For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq>`_.
+For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the :ref:`FAQs <faq>`.

 Resources for vLLM contributors
 -------------------------------
@@ -430,7 +430,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 Returns:
 shape = [num_tokens, num_heads * head_size]
 """
-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 if attn_type != AttentionType.DECODER:
 raise NotImplementedError("Encoder self-attention and "
@@ -509,7 +509,7 @@ class ModelConfig:
 self.use_async_output_proc = False
 return

-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"):
 logger.warning(
@@ -525,7 +525,7 @@ class ModelConfig:
 self.use_async_output_proc = False
 return

-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 if device_config.device_type == "cuda" and self.enforce_eager:
 logger.warning(
@@ -540,7 +540,7 @@ class ModelConfig:
 if self.task == "embedding":
 self.use_async_output_proc = False

-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 if speculative_config:
 logger.warning("Async output processing is not supported with"
@@ -1704,7 +1704,7 @@ class LoRAConfig:
 model_config.quantization)

 def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 if scheduler_config.chunked_prefill_enabled:
 raise ValueError("LoRA is not supported with chunked prefill yet.")
@@ -1111,7 +1111,7 @@ class EngineArgs:
 disable_logprobs=self.disable_logprobs_during_spec_decoding,
 )

-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 if self.num_scheduler_steps > 1:
 if speculative_config is not None:
@@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
 @staticmethod
 @functools.lru_cache
 def _log_prompt_logprob_unsupported_warning_once():
-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 logger.warning(
 "Prompt logprob is not supported by multi step workers. "
@@ -23,7 +23,7 @@ class CPUExecutor(ExecutorBase):

 def _init_executor(self) -> None:
 assert self.device_config.device_type == "cpu"
-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 assert self.lora_config is None, "cpu backend doesn't support LoRA"

@@ -46,7 +46,7 @@ class CpuPlatform(Platform):
 import vllm.envs as envs
 from vllm.utils import GiB_bytes
 model_config = vllm_config.model_config
-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 if not model_config.enforce_eager:
 logger.warning(
@@ -104,7 +104,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
 return spec_decode_worker


-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 class SpecDecodeWorker(LoraNotSupportedWorkerBase):
 """Worker which implements speculative decoding.
@@ -47,7 +47,7 @@ logger = init_logger(__name__)

 # Exception strings for non-implemented encoder/decoder scenarios

-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid

 STR_NOT_IMPL_ENC_DEC_SWA = \
@@ -817,7 +817,7 @@ def _pythonize_sampler_output(

 for sgdx, (seq_group,
 sample_result) in enumerate(zip(seq_groups, samples_list)):
-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid
 # (Check for Guided Decoding)
 if seq_group.sampling_params.logits_processors:
@@ -13,7 +13,7 @@ def assert_enc_dec_mr_supported_scenario(
 a supported scenario.
 '''

-# Reminder: Please update docs/source/serving/compatibility_matrix.rst
+# Reminder: Please update docs/source/usage/compatibility_matrix.rst
 # If the feature combo become valid

 if enc_dec_mr.cache_config.enable_prefix_caching: