[Doc] Improve OOM troubleshooting (#16704)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-04-16 18:29:48 +08:00 committed by GitHub
parent 7168920491
commit facbe2a114
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 54 additions and 4 deletions

View File

@ -24,7 +24,7 @@ To isolate the model downloading and loading issue, you can use the `--load-form
## Out of memory ## Out of memory
If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider [using tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism. If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider adopting [these options](#reducing-memory-usage) to reduce the memory consumption.
## Generation quality changed ## Generation quality changed

View File

@ -59,6 +59,8 @@ model = LLM(
Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM. Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM.
(reducing-memory-usage)=
### Reducing memory usage ### Reducing memory usage
Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem. Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem.
@ -81,6 +83,12 @@ before initializing vLLM. Otherwise, you may run into an error like `RuntimeErro
To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable. To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable.
::: :::
:::{note}
With tensor parallelism enabled, each process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism).
You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
:::
#### Quantization #### Quantization
Quantized models take less memory at the cost of lower precision. Quantized models take less memory at the cost of lower precision.
@ -103,6 +111,39 @@ llm = LLM(model="adept/fuyu-8b",
max_num_seqs=2) max_num_seqs=2)
``` ```
#### Reduce CUDA Graphs
By default, we optimize model inference using CUDA graphs which take up extra memory in the GPU.
:::{important}
CUDA graph capture takes up more memory in V1 than in V0.
:::
You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:
```python
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel
llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
# By default, it goes up to max_num_seqs
cudagraph_capture_sizes=[1, 2, 4, 8, 16],
),
)
```
You can disable graph capturing completely via the `enforce_eager` flag:
```python
from vllm import LLM
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
enforce_eager=True)
```
#### Adjust cache size #### Adjust cache size
If you run out of CPU RAM, try the following options: If you run out of CPU RAM, try the following options:
@ -110,16 +151,25 @@ If you run out of CPU RAM, try the following options:
- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB). - (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB).
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
#### Disable unused modalities #### Multi-modal input limits
You can disable unused modalities (except for text) by setting its limit to zero. You can allow a smaller number of multi-modal items per prompt to reduce the memory footprint of the model:
```python
from vllm import LLM
# Accept up to 3 images and 1 video per prompt
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
limit_mm_per_prompt={"image": 3, "video": 1})
```
You can go a step further and disable unused modalities completely by setting its limit to zero.
For example, if your application only accepts image input, there is no need to allocate any memory for videos. For example, if your application only accepts image input, there is no need to allocate any memory for videos.
```python ```python
from vllm import LLM from vllm import LLM
# Accept images but not videos # Accept any number of images but no videos
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
limit_mm_per_prompt={"video": 0}) limit_mm_per_prompt={"video": 0})
``` ```