[Doc] Create a new "Usage" section (#10827)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-12-05 11:19:35 +08:00 committed by GitHub
parent 8d370e91cb
commit aa39a8e175
GPG Key ID: B5690EEEBB952194
25 changed files with 218 additions and 125 deletions

View File

@ -7,7 +7,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_mm_models>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.
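For example, a minimal sketch of passing an image this way (assuming the LLaVA-1.5 checkpoint and a local image file):

.. code-block:: python

from vllm import LLM
from PIL import Image

llm = LLM(model="llava-hf/llava-1.5-7b-hf")
image = Image.open("example.jpg")

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is in this image?\nASSISTANT:",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)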
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
@ -15,9 +15,6 @@ by following :ref:`this guide <adding_multimodal_plugin>`.
Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.
..
TODO: Add usage of --limit-mm-per-prompt when multi-image input is officially supported
Guides
++++++

View File

@ -85,12 +85,8 @@ Documentation
serving/deploying_with_nginx
serving/distributed_serving
serving/metrics
serving/env_vars
serving/usage_stats
serving/integrations
serving/tensorizer
serving/compatibility_matrix
serving/faq
.. toctree::
:maxdepth: 1
@ -99,12 +95,21 @@ Documentation
models/supported_models
models/adding_model
models/enabling_multimodal_inputs
models/engine_args
models/lora
models/vlm
models/structured_outputs
models/spec_decode
models/performance
.. toctree::
:maxdepth: 1
:caption: Usage
usage/lora
usage/multimodal_inputs
usage/structured_outputs
usage/spec_decode
usage/compatibility_matrix
usage/performance
usage/faq
usage/engine_args
usage/env_vars
usage/usage_stats
.. toctree::
:maxdepth: 1

View File

@ -3,7 +3,7 @@
Enabling Multimodal Inputs
==========================
This document walks you through the steps to extend a vLLM model so that it accepts :ref:`multi-modal <multi_modality>` inputs.
This document walks you through the steps to extend a vLLM model so that it accepts :ref:`multi-modal inputs <multimodal_inputs>`.
.. seealso::
:ref:`adding_a_new_model`

View File

@ -471,6 +471,8 @@ Sentence Pair Scoring
.. note::
These models are supported in both offline and online inference via the Score API.
.. _supported_mm_models:
Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -489,8 +491,6 @@ On the other hand, modalities separated by :code:`/` are mutually exclusive.
- e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs.
.. _supported_vlms:
Text Generation
---------------
@ -646,6 +646,21 @@ Text Generation
| :sup:`E` Pre-computed embeddings can be inputted for this modality.
| :sup:`+` Multiple items can be inputted per text prompt for this modality.
.. important::
To enable multiple multi-modal items per text prompt, you have to set :code:`limit_mm_per_prompt` (offline inference)
or :code:`--limit-mm-per-prompt` (online inference). For example, to enable passing up to 4 images per text prompt:
.. code-block:: python
llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    limit_mm_per_prompt={"image": 4},
)
.. code-block:: bash
vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4
.. note::
vLLM currently only supports adding LoRA to the language backbone of multimodal models.

View File

@ -32,7 +32,7 @@ We currently support the following OpenAI APIs:
- [Completions API](https://platform.openai.com/docs/api-reference/completions)
- *Note: `suffix` parameter is not supported.*
- [Chat Completions API](https://platform.openai.com/docs/api-reference/chat)
- [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Using VLMs](../models/vlm.rst).
- [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Multimodal Inputs](../usage/multimodal_inputs.rst).
- *Note: `image_url.detail` parameter is not supported.*
- We also support `audio_url` content type for audio files.
- Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema.
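For instance, a vision chat request might look like this (a sketch; the model name and server address are assumptions):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="microsoft/Phi-3.5-vision-instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)
```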
@ -41,7 +41,7 @@ We currently support the following OpenAI APIs:
- [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings)
- Instead of `inputs`, you can pass in a list of `messages` (same schema as Chat Completions API),
which will be treated as a single prompt to the model according to its chat template.
- This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst).
- This enables multi-modal inputs to be passed to embedding models, see [this page](../usage/multimodal_inputs.rst) for details.
- *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.*
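For instance, a multi-modal embedding request might look like this (a sketch using `requests`, since the OpenAI client does not define this schema; the model name is an assumption):

```python
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "TIGER-Lab/VLM2Vec-Full",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
                {"type": "text", "text": "Represent the given image."},
            ],
        }],
        "encoding_format": "float",
    },
)
print(response.json()["data"][0]["embedding"])
```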
## Score API for Cross Encoder Models

View File

@ -1,3 +1,5 @@
.. _faq:
Frequently Asked Questions
===========================

View File

@ -1,7 +1,7 @@
.. _lora:
Using LoRA adapters
===================
LoRA Adapters
=============
This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.
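As a quick taste of the offline flow described in this document, a minimal sketch (the base model and adapter path are assumptions):

.. code-block:: python

from vllm import LLM
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

outputs = llm.generate(
    "Hello, my name is",
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql_lora_adapter"),
)
print(outputs[0].outputs[0].text)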

View File

@ -1,34 +1,31 @@
.. _vlm:
.. _multimodal_inputs:
Using VLMs
==========
Multimodal Inputs
=================
vLLM provides experimental support for Vision Language Models (VLMs). See the :ref:`list of supported VLMs here <supported_vlms>`.
This document shows you how to run and serve these models using vLLM.
This page teaches you how to pass multi-modal inputs to :ref:`multi-modal models <supported_mm_models>` in vLLM.
.. note::
We are actively iterating on VLM support. See `this RFC <https://github.com/vllm-project/vllm/issues/4194>`_ for upcoming changes,
We are actively iterating on multi-modal support. See `this RFC <https://github.com/vllm-project/vllm/issues/4194>`_ for upcoming changes,
and `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.
Offline Inference
-----------------
Single-image input
^^^^^^^^^^^^^^^^^^
The :class:`~vllm.LLM` class can be instantiated in much the same way as language-only models.
.. code-block:: python
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:
To input multi-modal data, follow this schema in :class:`vllm.inputs.PromptType`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
Image
^^^^^
You can pass a single image to the :code:`'image'` field of the multi-modal dictionary, as shown in the following examples:
.. code-block:: python
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
@ -41,41 +38,6 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
"multi_modal_data": {"image": image},
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
# Inference with image embeddings as input
image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {"image": image_embeds},
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
# Inference with image embeddings as input with additional parameters
# Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which utilizes additional parameters.
mm_data = {}
image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
# For Qwen2VL, image_grid_thw is needed to calculate positional encoding.
mm_data['image'] = {
"image_embeds": image_embeds,
"image_grid_thw": torch.load(...) # torch.Tensor of shape (1, 3),
}
# For MiniCPM-V, image_size_list is needed to calculate details of the sliced image.
mm_data['image'] = {
"image_embeds": image_embeds,
"image_size_list": [image.size] # list of image sizes
}
outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": mm_data,
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
@ -102,12 +64,7 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT
A code example can be found in `examples/offline_inference_vision_language.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_vision_language.py>`_.
Multi-image input
^^^^^^^^^^^^^^^^^
Multi-image input is only supported for a subset of VLMs, as shown :ref:`here <supported_vlms>`.
To enable multiple multi-modal items per text prompt, you have to set ``limit_mm_per_prompt`` for the :class:`~vllm.LLM` class.
To substitute multiple images inside the same text prompt, you can pass in a list of images instead:
.. code-block:: python
@ -118,10 +75,6 @@ To enable multiple multi-modal items per text prompt, you have to set ``limit_mm
limit_mm_per_prompt={"image": 2}, # The maximum number to accept
)
Instead of passing in a single image, you can pass in a list of images.
.. code-block:: python
# Refer to the HuggingFace repo for the correct format to use
prompt = "<|user|>\n<|image_1|>\n<|image_2|>\nWhat is the content of each image?<|end|>\n<|assistant|>\n"
@ -169,30 +122,114 @@ Multi-image input can be extended to perform video captioning. We show this with
generated_text = o.outputs[0].text
print(generated_text)
Video
^^^^^
You can pass a list of NumPy arrays directly to the :code:`'video'` field of the multi-modal dictionary
instead of using multi-image input.
Please refer to `examples/offline_inference_vision_language.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_vision_language.py>`_ for more details.
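For illustration, a sketch with synthetic frames (the model name and prompt template here are placeholders; consult the model's HuggingFace repo for the real format):

.. code-block:: python

import numpy as np
from vllm import LLM

llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf")

# 8 dummy RGB frames of shape (height, width, channels); real code would
# sample frames from an actual video file.
video = np.zeros((8, 224, 224, 3), dtype=np.uint8)

outputs = llm.generate({
    "prompt": "USER: <video>\nSummarize this video.\nASSISTANT:",
    "multi_modal_data": {"video": video},
})
print(outputs[0].outputs[0].text)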
Audio
^^^^^
You can pass a tuple :code:`(array, sampling_rate)` to the :code:`'audio'` field of the multi-modal dictionary.
Please refer to `examples/offline_inference_audio_language.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_audio_language.py>`_ for more details.
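Similarly, a sketch with a synthetic waveform (the model name and prompt template are placeholders):

.. code-block:: python

import numpy as np
from vllm import LLM

llm = LLM(model="fixie-ai/ultravox-v0_3")

# One second of silence at 16 kHz; real code would load a waveform,
# e.g. with librosa or soundfile.
sampling_rate = 16000
audio = np.zeros(sampling_rate, dtype=np.float32)

outputs = llm.generate({
    "prompt": "<|user|>\n<|audio|>\nWhat can you hear?<|end|>\n<|assistant|>\n",  # placeholder template
    "multi_modal_data": {"audio": (audio, sampling_rate)},
})
print(outputs[0].outputs[0].text)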
Embedding
^^^^^^^^^
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape :code:`(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
.. code-block:: python
# Inference with image embeddings as input
llm = LLM(model="llava-hf/llava-1.5-7b-hf")
# Refer to the HuggingFace repo for the correct format to use
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"
# Embeddings for single image
# torch.Tensor of shape (1, image_feature_size, hidden_size of LM)
image_embeds = torch.load(...)
outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": image_embeds},
})

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)
For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embeddings:
.. code-block:: python
# Construct the prompt based on your model
prompt = ...
# Embeddings for multiple images
# torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM)
image_embeds = torch.load(...)
# Qwen2-VL
llm = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": 4})
mm_data = {
    "image": {
        "image_embeds": image_embeds,
        # image_grid_thw is needed to calculate positional encoding.
        "image_grid_thw": torch.load(...),  # torch.Tensor of shape (1, 3)
    }
}

# MiniCPM-V
llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={"image": 4})
mm_data = {
    "image": {
        "image_embeds": image_embeds,
        # image_size_list is needed to calculate details of the sliced image.
        "image_size_list": [image.size for image in images],  # list of image sizes
    }
}

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": mm_data,
})

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)
Online Inference
----------------
OpenAI Vision API
^^^^^^^^^^^^^^^^^
Our OpenAI-compatible server accepts multi-modal data via the `Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`_.
You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_.
.. important::
A chat template is **required** to use Chat Completions API.
Below is an example of how to launch the same ``microsoft/Phi-3.5-vision-instruct`` with vLLM's OpenAI-compatible API server.
Although most models come with a chat template, for others you have to define one yourself.
The chat template can be inferred based on the documentation on the model's HuggingFace repo.
For example, LLaVA-1.5 (``llava-hf/llava-1.5-7b-hf``) requires a chat template that can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`__.
Image
^^^^^
Image input is supported according to `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_.
Here is a simple example using Phi-3.5-Vision.
First, launch the OpenAI-compatible server:
.. code-block:: bash
vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
.. important::
Since OpenAI Vision API is based on `Chat Completions API <https://platform.openai.com/docs/api-reference/chat>`_,
a chat template is **required** to launch the API server.
Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the tokenizer does not include it.
The chat template can be inferred based on the documentation on the model's HuggingFace repo.
For example, LLaVA-1.5 (``llava-hf/llava-1.5-7b-hf``) requires a chat template that can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`_.
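If the tokenizer does not ship with a template, you can supply one at launch via ``--chat-template`` (a sketch, assuming a local copy of the LLaVA template):

.. code-block:: bash

vllm serve llava-hf/llava-1.5-7b-hf --chat-template ./examples/template_llava.jinja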
To consume the server, you can use the OpenAI client like in the example below:
Then, you can use the OpenAI client as follows:
.. code-block:: python
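# A sketch of the client call, assuming the server launched above; the full
# version lives in examples/openai_chat_completion_client_for_multimodal.py.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_response = client.chat.completions.create(
    model="microsoft/Phi-3.5-vision-instruct",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
        ],
    }],
)
print("Chat completion output:", chat_response.choices[0].message.content)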
@ -252,22 +289,59 @@ A full code example can be found in `examples/openai_chat_completion_client_for_
.. note::
By default, the timeout for fetching images through http url is ``5`` seconds. You can override this by setting the environment variable:
By default, the timeout for fetching images through HTTP URL is ``5`` seconds.
You can override this by setting the environment variable:
.. code-block:: console
$ export VLLM_IMAGE_FETCH_TIMEOUT=<timeout>
Chat Embeddings API
^^^^^^^^^^^^^^^^^^^
Video
^^^^^
vLLM's Chat Embeddings API is a superset of OpenAI's `Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`_,
where a list of ``messages`` can be passed instead of batched ``inputs``. This enables multi-modal inputs to be passed to embedding models.
Instead of :code:`image_url`, you can pass a video file via :code:`video_url`.
You can use `these tests <https://github.com/vllm-project/vllm/blob/main/tests/entrypoints/openai/test_video.py>`_ as reference.
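For illustration, a sketch of such a request (the model name is a placeholder for whichever video-capable model is being served; the schema mirrors :code:`image_url`):

.. code-block:: python

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_response = client.chat.completions.create(
    model="llava-hf/LLaVA-NeXT-Video-7B-hf",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this video."},
            {"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}},
        ],
    }],
)
print(chat_response.choices[0].message.content)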
.. note::
By default, the timeout for fetching videos through HTTP URL is ``30`` seconds.
You can override this by setting the environment variable:
.. code-block:: console
$ export VLLM_VIDEO_FETCH_TIMEOUT=<timeout>
Audio
^^^^^
Instead of :code:`image_url`, you can pass an audio file via :code:`audio_url`.
A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_chat_completion_client_for_multimodal.py>`_.
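For illustration, a sketch of such a request (the model name and server address are assumptions):

.. code-block:: python

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_response = client.chat.completions.create(
    model="fixie-ai/ultravox-v0_3",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Transcribe this audio."},
            {"type": "audio_url", "audio_url": {"url": "https://example.com/audio.wav"}},
        ],
    }],
)
print(chat_response.choices[0].message.content)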
.. note::
By default, the timeout for fetching audio through HTTP URL is ``10`` seconds.
You can override this by setting the environment variable:
.. code-block:: console
$ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>
Embedding
^^^^^^^^^
vLLM's Embeddings API is a superset of OpenAI's `Embeddings API <https://platform.openai.com/docs/api-reference/embeddings>`_,
where a list of chat ``messages`` can be passed instead of batched ``inputs``. This enables multi-modal inputs to be passed to embedding models.
.. tip::
The schema of ``messages`` is exactly the same as in Chat Completions API.
You can refer to the above tutorials for more details on how to pass each type of multi-modal data.
In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
Usually, embedding models do not expect chat-based input, so we need to use a custom chat template to format the text and images.
Refer to the examples below for illustration.
Here is an end-to-end example using VLM2Vec. To serve the model:
.. code-block:: bash
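# A sketch of the launch command (flags inferred from the surrounding text):
vllm serve TIGER-Lab/VLM2Vec-Full --task embedding \
    --trust-remote-code --chat-template examples/template_vlm2vec.jinja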
@ -279,10 +353,8 @@ In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model.
Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding``
to run this model in embedding mode instead of text generation mode.
.. important::
VLM2Vec does not expect chat-based input. We use a `custom chat template <https://github.com/vllm-project/vllm/blob/main/examples/template_vlm2vec.jinja>`_
to combine the text and images together.
The custom chat template is completely different from the original one for this model,
and can be found `here <https://github.com/vllm-project/vllm/blob/main/examples/template_vlm2vec.jinja>`__.
Since the request schema is not defined by the OpenAI client, we post a request to the server using the lower-level ``requests`` library:
@ -310,7 +382,7 @@ Since the request schema is not defined by OpenAI client, we post a request to t
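# A sketch of the request, assuming the VLM2Vec server launched above;
# /v1/embeddings accepts chat-style messages in place of batched inputs.
import requests

image_url = "https://example.com/image.jpg"
response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "TIGER-Lab/VLM2Vec-Full",
        "messages": [{
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": "Represent the given image."},
            ],
        }],
        "encoding_format": "float",
    },
)
response.raise_for_status()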
response_json = response.json()
print("Embedding output:", response_json["data"][0]["embedding"])
Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model.
Below is another example, this time using the ``MrLight/dse-qwen2-2b-mrl-v1`` model.
.. code-block:: bash
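# A sketch of the launch command (flags inferred from the surrounding text):
vllm serve MrLight/dse-qwen2-2b-mrl-v1 --task embedding \
    --trust-remote-code --chat-template examples/template_dse_qwen2_vl.jinja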
@ -319,8 +391,10 @@ Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model.
.. important::
Like with VLM2Vec, we have to explicitly pass ``--task embedding``. Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings,
which is handled by the jinja template.
Like with VLM2Vec, we have to explicitly pass ``--task embedding``.
Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings, which is handled
by `this custom chat template <https://github.com/vllm-project/vllm/blob/main/examples/template_dse_qwen2_vl.jinja>`__.
.. important::

View File

@ -1,7 +1,7 @@
.. _spec_decode:
Speculative decoding in vLLM
============================
Speculative decoding
====================
.. warning::
Please note that speculative decoding in vLLM is not yet optimized and does
@ -182,7 +182,7 @@ speculative decoding, breaking down the guarantees into three key areas:
3. **vLLM Logprob Stability**
- vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the
same request across runs. For more details, see the FAQ section
titled *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq>`_.
titled *Can the output of a prompt vary across runs in vLLM?* in the :ref:`FAQs <faq>`.
**Conclusion**
@ -197,7 +197,7 @@ can occur due to following factors:
**Mitigation Strategies**
For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq>`_.
For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the :ref:`FAQs <faq>`.
Resources for vLLM contributors
-------------------------------

View File

@ -430,7 +430,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "

View File

@ -509,7 +509,7 @@ class ModelConfig:
self.use_async_output_proc = False
return
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"):
logger.warning(
@ -525,7 +525,7 @@ class ModelConfig:
self.use_async_output_proc = False
return
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if device_config.device_type == "cuda" and self.enforce_eager:
logger.warning(
@ -540,7 +540,7 @@ class ModelConfig:
if self.task == "embedding":
self.use_async_output_proc = False
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if speculative_config:
logger.warning("Async output processing is not supported with"
@ -1704,7 +1704,7 @@ class LoRAConfig:
model_config.quantization)
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if scheduler_config.chunked_prefill_enabled:
raise ValueError("LoRA is not supported with chunked prefill yet.")

View File

@ -1111,7 +1111,7 @@ class EngineArgs:
disable_logprobs=self.disable_logprobs_during_spec_decoding,
)
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if self.num_scheduler_steps > 1:
if speculative_config is not None:

View File

@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@staticmethod
@functools.lru_cache
def _log_prompt_logprob_unsupported_warning_once():
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
logger.warning(
"Prompt logprob is not supported by multi step workers. "

View File

@ -23,7 +23,7 @@ class CPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
assert self.lora_config is None, "cpu backend doesn't support LoRA"

View File

@ -46,7 +46,7 @@ class CpuPlatform(Platform):
import vllm.envs as envs
from vllm.utils import GiB_bytes
model_config = vllm_config.model_config
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if not model_config.enforce_eager:
logger.warning(

View File

@ -104,7 +104,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
return spec_decode_worker
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""Worker which implements speculative decoding.

View File

@ -47,7 +47,7 @@ logger = init_logger(__name__)
# Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
STR_NOT_IMPL_ENC_DEC_SWA = \

View File

@ -817,7 +817,7 @@ def _pythonize_sampler_output(
for sgdx, (seq_group,
sample_result) in enumerate(zip(seq_groups, samples_list)):
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
# (Check for Guided Decoding)
if seq_group.sampling_params.logits_processors:

View File

@ -13,7 +13,7 @@ def assert_enc_dec_mr_supported_scenario(
a supported scenario.
'''
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/usage/compatibility_matrix.rst
# If the feature combo become valid
if enc_dec_mr.cache_config.enable_prefix_caching: