[Doc][2/N] Reorganize Models and Usage sections (#11755)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-01-06 21:40:31 +08:00 committed by GitHub
parent 996357e480
commit ee77fdb5de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
45 changed files with 265 additions and 238 deletions

View File

@ -9,7 +9,7 @@ body:
value: >
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
#### We also highly recommend you read https://docs.vllm.ai/en/latest/models/adding_model.html first to understand how to add a new model.
#### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/adding_model.html first to understand how to add a new model.
- type: textarea
attributes:
label: The model to consider.

View File

Before

Width:  |  Height:  |  Size: 102 KiB

After

Width:  |  Height:  |  Size: 102 KiB

View File

Before

Width:  |  Height:  |  Size: 173 KiB

After

Width:  |  Height:  |  Size: 173 KiB

View File

@ -0,0 +1,102 @@
(new-model-basic)=
# Basic Implementation
This guide walks you through the steps to implement a basic vLLM model.
## 1. Bring your model code
First, clone the PyTorch model code from the source repository.
For instance, vLLM's [OPT model](gh-file:vllm/model_executor/models/opt.py) was adapted from
HuggingFace's [modeling_opt.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py) file.
```{warning}
Make sure to review and adhere to the original code's copyright and licensing terms!
```
## 2. Make your code compatible with vLLM
To ensure compatibility with vLLM, your model must meet the following requirements:
### Initialization Code
All vLLM modules within the model must include a `prefix` argument in their constructor. This `prefix` is typically the full name of the module in the model's state dictionary and is crucial for:
- Runtime support: vLLM's attention operators are registered in a model's state by their full names. Each attention operator must have a unique prefix as its layer name to avoid conflicts.
- Non-uniform quantization support: A quantized checkpoint can selectively quantize certain layers while keeping others in full precision. By providing the `prefix` during initialization, vLLM can match the current layer's `prefix` with the quantization configuration to determine if the layer should be initialized in quantized mode.
The initialization code should look like this:
```python
from torch import nn
from vllm.config import VllmConfig
from vllm.attention import Attention
class MyAttention(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.attn = Attention(prefix=f"{prefix}.attn")
class MyDecoderLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.self_attn = MyAttention(prefix=f"{prefix}.self_attn")
class MyModel(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.layers = nn.ModuleList(
[MyDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}") for i in range(vllm_config.model_config.hf_config.num_hidden_layers)]
)
class MyModelForCausalLM(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = MyModel(vllm_config, prefix=f"{prefix}.model")
```
### Computation Code
Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.
```python
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
...
```
```{note}
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
```
For reference, check out our [Llama implementation](gh-file:vllm/model_executor/models/llama.py). vLLM already supports a large number of models. It is recommended to find a model similar to yours and adapt it to your model's architecture. Check out <gh-dir:vllm/model_executor/models> for more examples.
## 3. (Optional) Implement tensor parallelism and quantization support
If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
For the embedding layer, you can simply replace {class}`torch.nn.Embedding` with `VocabParallelEmbedding`. For the output LM head, you can use `ParallelLMHead`.
When it comes to the linear layers, we provide the following options to parallelize them:
- `ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
- `RowParallelLinear`: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An *all-reduce* operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer.
- `ColumnParallelLinear`: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer.
- `MergedColumnParallelLinear`: Column-parallel linear that merges multiple `ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
- `QKVParallelLinear`: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When number of key/value heads are less than the world size, this class replicates the key/value heads properly. This class handles the weight loading and replication of the weight matrices.
Note that all the linear layers above take `linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
## 4. Implement the weight loading logic
You now need to implement the `load_weights` method in your `*ForCausalLM` class.
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for `MergedColumnParallelLinear` and `QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately.
## 5. Register your model
See [this page](#new-model-registration) for instructions on how to register your new model to be used by vLLM.

View File

@ -0,0 +1,26 @@
(new-model)=
# Adding a New Model
This section provides more information on how to integrate a [HuggingFace Transformers](https://github.com/huggingface/transformers) model into vLLM.
```{toctree}
:caption: Contents
:maxdepth: 1
basic
registration
multimodal
```
```{note}
The complexity of adding a new model depends heavily on the model's architecture.
The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
```
```{tip}
If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues)
or ask on our [developer slack](https://slack.vllm.ai).
We will be happy to help you out!
```

View File

@ -2,15 +2,11 @@
# Enabling Multimodal Inputs
This document walks you through the steps to extend a vLLM model so that it accepts [multi-modal inputs](#multimodal-inputs).
```{seealso}
[Adding a New Model](adding-a-new-model)
```
This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs](#multimodal-inputs).
## 1. Update the base vLLM model
It is assumed that you have already implemented the model in vLLM according to [these steps](#adding-a-new-model).
It is assumed that you have already implemented the model in vLLM according to [these steps](#new-model-basic).
Further update the model as follows:
- Implement the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.

View File

@ -0,0 +1,56 @@
(new-model-registration)=
# Model Registration
vLLM relies on a model registry to determine how to run each model.
A list of pre-registered architectures can be found on the [Supported Models](#supported-models) page.
If your model is not on this list, you must register it to vLLM.
This page provides detailed instructions on how to do so.
## Built-in models
To add a model directly to the vLLM library, start by forking our [GitHub repository](https://github.com/vllm-project/vllm) and then [build it from source](#build-from-source).
This gives you the ability to modify the codebase and test your model.
After you have implemented your model (see [tutorial](#new-model-basic)), put it into the <gh-dir:vllm/model_executor/models> directory.
Then, add your model class to `_VLLM_MODELS` in <gh-file:vllm/model_executor/models/registry.py> so that it is automatically registered upon importing vLLM.
You should also include an example HuggingFace repository for this model in <gh-file:tests/models/registry.py> to run the unit tests.
Finally, update the [Supported Models](#supported-models) documentation page to promote your model!
```{important}
The list of models in each section should be maintained in alphabetical order.
```
## Out-of-tree models
You can load an external model using a plugin without modifying the vLLM codebase.
```{seealso}
[vLLM's Plugin System](#plugin-system)
```
To register the model, use the following code:
```python
from vllm import ModelRegistry
from your_code import YourModelForCausalLM
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
```
If your model imports modules that initialize CUDA, consider lazy-importing it to avoid errors like `RuntimeError: Cannot re-initialize CUDA in forked subprocess`:
```python
from vllm import ModelRegistry
ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCausalLM")
```
```{important}
If your model is a multimodal model, ensure the model class implements the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
Read more about that [here](#enabling-multimodal-inputs).
```
```{note}
Although you can directly put these code snippets in your script using `vllm.LLM`, the recommended way is to place these snippets in a vLLM plugin. This ensures compatibility with various vLLM features like distributed inference and the API server.
```

View File

@ -1,6 +1,8 @@
# Implementation
(design-automatic-prefix-caching)=
The core idea of PagedAttention is to partition the KV cache of each request into KV Blocks. Each block contains the attention keys and values for a fixed number of tokens. The PagedAttention algorithm allows these blocks to be stored in non-contiguous physical memory so that we can eliminate memory fragmentation by allocating the memory on demand.
# Automatic Prefix Caching
The core idea of [PagedAttention](#design-paged-attention) is to partition the KV cache of each request into KV Blocks. Each block contains the attention keys and values for a fixed number of tokens. The PagedAttention algorithm allows these blocks to be stored in non-contiguous physical memory so that we can eliminate memory fragmentation by allocating the memory on demand.
To automatically cache the KV cache, we utilize the following key observation: Each KV block can be uniquely identified by the tokens within the block and the tokens in the prefix before the block.

View File

@ -1,3 +1,5 @@
(design-paged-attention)=
# vLLM Paged Attention
- Currently, vLLM utilizes its own implementation of a multi-head query

View File

@ -1,6 +1,7 @@
# Offline Inference
```{toctree}
:caption: Contents
:maxdepth: 1
llm

View File

@ -1,13 +1,13 @@
(apc)=
(automatic-prefix-caching)=
# Introduction
# Automatic Prefix Caching
## What is Automatic Prefix Caching
## Introduction
Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part.
```{note}
Technical details on how vLLM implements APC are in the next page.
Technical details on how vLLM implements APC can be found [here](#design-automatic-prefix-caching).
```
## Enabling APC in vLLM

View File

@ -32,7 +32,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar
* - Feature
- [CP](#chunked-prefill)
- [APC](#apc)
- [APC](#automatic-prefix-caching)
- [LoRA](#lora-adapter)
- <abbr title="Prompt Adapter">prmpt adptr</abbr>
- [SD](#spec_decode)
@ -64,7 +64,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar
-
-
-
* - [APC](#apc)
* - [APC](#automatic-prefix-caching)
- ✅
-
-
@ -345,7 +345,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar
- ✅
- ✅
- ✅
* - [APC](#apc)
* - [APC](#automatic-prefix-caching)
- [](gh-issue:3687)
- ✅
- ✅

View File

@ -41,13 +41,13 @@ Key abstractions for disaggregated prefilling:
Here is a figure illustrating how the above 3 abstractions are organized:
```{image} /assets/usage/disagg_prefill/abstraction.jpg
```{image} /assets/features/disagg_prefill/abstraction.jpg
:alt: Disaggregated prefilling abstractions
```
The workflow of disaggregated prefilling is as follows:
```{image} /assets/usage/disagg_prefill/overview.jpg
```{image} /assets/features/disagg_prefill/overview.jpg
:alt: Disaggregated prefilling workflow
```

View File

@ -0,0 +1,19 @@
(quantization-index)=
# Quantization
Quantization trades off model precision for smaller memory footprint, allowing large models to be run on a wider range of devices.
```{toctree}
:caption: Contents
:maxdepth: 1
supported_hardware
auto_awq
bnb
gguf
int8
fp8
fp8_e5m2_kvcache
fp8_e4m3_kvcache
```

View File

@ -1,6 +1,6 @@
(supported-hardware-for-quantization)=
(quantization-supported-hardware)=
# Supported Hardware for Quantization Kernels
# Supported Hardware
The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM:
@ -120,12 +120,12 @@ The table below shows the compatibility of various quantization implementations
- ✗
```
## Notes:
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
- "✅︎" indicates that the quantization method is supported on the specified hardware.
- "✗" indicates that the quantization method is not supported on the specified hardware.
Please note that this compatibility chart may be subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods.
```{note}
This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods.
For the most up-to-date information on hardware support and quantization methods, please refer to <gh-dir:vllm/model_executor/layers/quantization> or consult with the vLLM development team.
```

View File

@ -79,6 +79,9 @@ serving/metrics
serving/integrations
serving/tensorizer
serving/runai_model_streamer
serving/engine_args
serving/env_vars
serving/usage_stats
```
```{toctree}
@ -88,53 +91,28 @@ serving/runai_model_streamer
models/supported_models
models/generative_models
models/pooling_models
models/adding_model
models/enabling_multimodal_inputs
```
```{toctree}
:caption: Usage
:caption: Features
:maxdepth: 1
usage/lora
usage/multimodal_inputs
usage/tool_calling
usage/structured_outputs
usage/spec_decode
usage/compatibility_matrix
usage/performance
usage/engine_args
usage/env_vars
usage/usage_stats
usage/disagg_prefill
```
```{toctree}
:caption: Quantization
:maxdepth: 1
quantization/supported_hardware
quantization/auto_awq
quantization/bnb
quantization/gguf
quantization/int8
quantization/fp8
quantization/fp8_e5m2_kvcache
quantization/fp8_e4m3_kvcache
```
```{toctree}
:caption: Automatic Prefix Caching
:maxdepth: 1
automatic_prefix_caching/apc
automatic_prefix_caching/details
features/quantization/index
features/lora
features/multimodal_inputs
features/tool_calling
features/structured_outputs
features/automatic_prefix_caching
features/disagg_prefill
features/spec_decode
features/compatibility_matrix
```
```{toctree}
:caption: Performance
:maxdepth: 1
performance/optimization
performance/benchmarks
```
@ -148,10 +126,8 @@ community/meetups
community/sponsors
```
% API Documentation: API reference aimed at vllm library usage
```{toctree}
:caption: API Documentation
:caption: API Reference
:maxdepth: 2
dev/sampling_params
@ -160,30 +136,32 @@ dev/offline_inference/offline_index
dev/engine/engine_index
```
% Design: docs about vLLM internals
% Design Documents: Details about vLLM internals
```{toctree}
:caption: Design
:caption: Design Documents
:maxdepth: 2
design/arch_overview
design/huggingface_integration
design/plugin_system
design/input_processing/model_inputs_index
design/kernel/paged_attention
design/input_processing/model_inputs_index
design/multimodal/multimodal_index
design/automatic_prefix_caching
design/multiprocessing
```
% For Developers: contributing to the vLLM project
% Developer Guide: How to contribute to the vLLM project
```{toctree}
:caption: For Developers
:caption: Developer Guide
:maxdepth: 2
contributing/overview
contributing/profiling/profiling_index
contributing/dockerfile/dockerfile
contributing/model/index
```
# Indices and tables

View File

@ -1,155 +0,0 @@
(adding-a-new-model)=
# Adding a New Model
This document provides a high-level guide on integrating a [HuggingFace Transformers](https://github.com/huggingface/transformers) model into vLLM.
```{note}
The complexity of adding a new model depends heavily on the model's architecture.
The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
```
```{note}
By default, vLLM models do not support multi-modal inputs. To enable multi-modal support,
please follow [this guide](#enabling-multimodal-inputs) after implementing the model here.
```
```{tip}
If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our [GitHub](https://github.com/vllm-project/vllm/issues) repository.
We will be happy to help you out!
```
## 0. Fork the vLLM repository
Start by forking our [GitHub] repository and then [build it from source](#build-from-source).
This gives you the ability to modify the codebase and test your model.
```{tip}
If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below.
```
## 1. Bring your model code
Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the <gh-dir:vllm/model_executor/models> directory.
For instance, vLLM's [OPT model](gh-file:vllm/model_executor/models/opt.py) was adapted from the HuggingFace's [modeling_opt.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py) file.
```{warning}
When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
```
## 2. Make your code compatible with vLLM
To ensure compatibility with vLLM, your model must meet the following requirements:
### Initialization Code
All vLLM modules within the model must include a `prefix` argument in their constructor. This `prefix` is typically the full name of the module in the model's state dictionary and is crucial for:
- Runtime support: vLLM's attention operators are registered in a model's state by their full names. Each attention operator must have a unique prefix as its layer name to avoid conflicts.
- Non-uniform quantization support: A quantized checkpoint can selectively quantize certain layers while keeping others in full precision. By providing the `prefix` during initialization, vLLM can match the current layer's `prefix` with the quantization configuration to determine if the layer should be initialized in quantized mode.
The initialization code should look like this:
```python
from torch import nn
from vllm.config import VllmConfig
from vllm.attention import Attention
class MyAttention(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.attn = Attention(prefix=f"{prefix}.attn")
class MyDecoderLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.self_attn = MyAttention(prefix=f"{prefix}.self_attn")
class MyModel(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str):
super().__init__()
self.layers = nn.ModuleList(
[MyDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}") for i in range(vllm_config.model_config.hf_config.num_hidden_layers)]
)
class MyModelForCausalLM(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model = MyModel(vllm_config, prefix=f"{prefix}.model")
```
### Computation Code
Rewrite the {meth}`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat `input_ids` and `positions` as flattened tensors with a single batch size dimension, without a max-sequence length dimension.
```python
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
...
```
```{note}
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
```
For reference, check out our [Llama implementation](gh-file:vllm/model_executor/models/llama.py). vLLM already supports a large number of models. It is recommended to find a model similar to yours and adapt it to your model's architecture. Check out <gh-dir:vllm/model_executor/models> for more examples.
## 3. (Optional) Implement tensor parallelism and quantization support
If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
For the embedding layer, you can simply replace {class}`torch.nn.Embedding` with {code}`VocabParallelEmbedding`. For the output LM head, you can use {code}`ParallelLMHead`.
When it comes to the linear layers, we provide the following options to parallelize them:
- {code}`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
- {code}`RowParallelLinear`: The input tensor is partitioned along the hidden dimension. The weight matrix is partitioned along the rows (input dimension). An *all-reduce* operation is performed after the matrix multiplication to reduce the results. Typically used for the second FFN layer and the output linear transformation of the attention layer.
- {code}`ColumnParallelLinear`: The input tensor is replicated. The weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separated QKV transformation of the attention layer in the original Transformer.
- {code}`MergedColumnParallelLinear`: Column-parallel linear that merges multiple {code}`ColumnParallelLinear` operators. Typically used for the first FFN layer with weighted activation functions (e.g., SiLU). This class handles the sharded weight loading logic of multiple weight matrices.
- {code}`QKVParallelLinear`: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When number of key/value heads are less than the world size, this class replicates the key/value heads properly. This class handles the weight loading and replication of the weight matrices.
Note that all the linear layers above take {code}`linear_method` as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
## 4. Implement the weight loading logic
You now need to implement the {code}`load_weights` method in your {code}`*ForCausalLM` class.
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model. Specifically, for {code}`MergedColumnParallelLinear` and {code}`QKVParallelLinear` layers, if the original model has separated weight matrices, you need to load the different parts separately.
## 5. Register your model
Finally, register your {code}`*ForCausalLM` class to the {code}`_VLLM_MODELS` in <gh-file:vllm/model_executor/models/registry.py>.
## 6. Out-of-Tree Model Integration
You can integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5. Instead, write a plugin to register your model. For general introduction of the plugin system, see [plugin-system](#plugin-system).
To register the model, use the following code:
```python
from vllm import ModelRegistry
from your_code import YourModelForCausalLM
ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
```
If your model imports modules that initialize CUDA, consider lazy-importing it to avoid errors like {code}`RuntimeError: Cannot re-initialize CUDA in forked subprocess`:
```python
from vllm import ModelRegistry
ModelRegistry.register_model("YourModelForCausalLM", "your_code:YourModelForCausalLM")
```
```{important}
If your model is a multimodal model, ensure the model class implements the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
Read more about that [here](#enabling-multimodal-inputs).
```
```{note}
Although you can directly put these code snippets in your script using `vllm.LLM`, the recommended way is to place these snippets in a vLLM plugin. This ensures compatibility with various vLLM features like distributed inference and the API server.
```

View File

@ -37,7 +37,7 @@ print(output)
If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported.
````
Otherwise, please refer to [Adding a New Model](#adding-a-new-model) and [Enabling Multimodal Inputs](#enabling-multimodal-inputs) for instructions on how to implement your model in vLLM.
Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM.
Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support.
### ModelScope

View File

@ -1,6 +1,6 @@
(performance)=
(optimization-and-tuning)=
# Performance and Tuning
# Optimization and Tuning
## Preemption

View File

@ -217,7 +217,7 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
We support both [Vision](https://platform.openai.com/docs/guides/vision)- and
[Audio](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in)-related parameters;
see our [Multimodal Inputs](../usage/multimodal_inputs.md) guide for more information.
see our [Multimodal Inputs](#multimodal-inputs) guide for more information.
- *Note: `image_url.detail` parameter is not supported.*
Code example: <gh-file:examples/openai_chat_completion_client.py>

View File

@ -430,7 +430,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "

View File

@ -644,7 +644,7 @@ class ModelConfig:
self.use_async_output_proc = False
return
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
from vllm.platforms import current_platform
if not current_platform.is_async_output_supported(self.enforce_eager):
@ -665,7 +665,7 @@ class ModelConfig:
if self.runner_type == "pooling":
self.use_async_output_proc = False
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
if speculative_config:
logger.warning("Async output processing is not supported with"
@ -2064,7 +2064,7 @@ class LoRAConfig:
model_config.quantization)
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
if scheduler_config.chunked_prefill_enabled:
logger.warning("LoRA with chunked prefill is still experimental "

View File

@ -1148,7 +1148,7 @@ class EngineArgs:
disable_logprobs=self.disable_logprobs_during_spec_decoding,
)
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
if self.num_scheduler_steps > 1:
if speculative_config is not None:

View File

@ -65,7 +65,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
@staticmethod
@functools.lru_cache
def _log_prompt_logprob_unsupported_warning_once():
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
logger.warning(
"Prompt logprob is not supported by multi step workers. "

View File

@ -22,7 +22,7 @@ class CPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
assert self.device_config.device_type == "cpu"
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
assert self.lora_config is None, "cpu backend doesn't support LoRA"

View File

@ -50,7 +50,7 @@ class CpuPlatform(Platform):
import vllm.envs as envs
from vllm.utils import GiB_bytes
model_config = vllm_config.model_config
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
if not model_config.enforce_eager:
logger.warning(

View File

@ -108,7 +108,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
return spec_decode_worker
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""Worker which implements speculative decoding.

View File

@ -58,7 +58,7 @@ logger = init_logger(__name__)
# Exception strings for non-implemented encoder/decoder scenarios
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
STR_NOT_IMPL_ENC_DEC_SWA = \

View File

@ -822,7 +822,7 @@ def _pythonize_sampler_output(
for sgdx, (seq_group,
sample_result) in enumerate(zip(seq_groups, samples_list)):
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
# (Check for Guided Decoding)
if seq_group.sampling_params.logits_processors:

View File

@ -13,7 +13,7 @@ def assert_enc_dec_mr_supported_scenario(
a supported scenario.
'''
# Reminder: Please update docs/source/usage/compatibility_matrix.md
# Reminder: Please update docs/source/features/compatibility_matrix.md
# If the feature combo become valid
if enc_dec_mr.cache_config.enable_prefix_caching: