[V1] Revert the default max_num_seqs
to V0 values for most hardware (#16158)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
027b204ff1
commit
66d433b94f
@ -156,10 +156,3 @@ vLLM V1 is currently optimized for decoder-only transformers. Models requiring
|
||||
cross-attention between separate encoder and decoder are not yet supported (e.g., `BartForConditionalGeneration`, `MllamaForConditionalGeneration`).
|
||||
|
||||
For a complete list of supported models, see the [list of supported models](https://docs.vllm.ai/en/latest/models/supported_models.html).
|
||||
|
||||
## Frequently Asked Questions
|
||||
|
||||
**I'm using vLLM V1 and I'm getting CUDA OOM errors. What should I do?**
|
||||
The default `max_num_seqs` has been raised from `256` in V0 to `1024` in V1. If you encounter CUDA OOM only when using V1 engine, try setting a lower value of `max_num_seqs` or `gpu_memory_utilization`.
|
||||
|
||||
On the other hand, if you get an error about insufficient memory for the cache blocks, you should increase `gpu_memory_utilization` as this indicates that your GPU has sufficient memory but you're not allocating enough to vLLM for KV cache blocks.
|
||||
|
@ -64,15 +64,17 @@ def test_defaults_with_usage_context():
|
||||
# For H100 and H200, we use larger default values.
|
||||
default_llm_tokens = 16384
|
||||
default_server_tokens = 8192
|
||||
default_max_num_seqs = 1024
|
||||
else:
|
||||
default_llm_tokens = 8192
|
||||
default_server_tokens = 2048
|
||||
default_max_num_seqs = 256
|
||||
|
||||
assert vllm_config.scheduler_config.max_num_seqs == 1024
|
||||
assert vllm_config.scheduler_config.max_num_seqs == default_max_num_seqs
|
||||
assert vllm_config.scheduler_config.max_num_batched_tokens == default_llm_tokens # noqa: E501
|
||||
|
||||
engine_args = EngineArgs(model="facebook/opt-125m")
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
UsageContext.OPENAI_API_SERVER)
|
||||
assert vllm_config.scheduler_config.max_num_seqs == 1024
|
||||
assert vllm_config.scheduler_config.max_num_seqs == default_max_num_seqs
|
||||
assert vllm_config.scheduler_config.max_num_batched_tokens == default_server_tokens # noqa: E501
|
||||
|
@ -1666,12 +1666,14 @@ class EngineArgs:
|
||||
UsageContext.LLM_CLASS: 16384,
|
||||
UsageContext.OPENAI_API_SERVER: 8192,
|
||||
}
|
||||
default_max_num_seqs = 1024
|
||||
else:
|
||||
# TODO(woosuk): Tune the default values for other hardware.
|
||||
default_max_num_batched_tokens = {
|
||||
UsageContext.LLM_CLASS: 8192,
|
||||
UsageContext.OPENAI_API_SERVER: 2048,
|
||||
}
|
||||
default_max_num_seqs = 256
|
||||
|
||||
use_context_value = usage_context.value if usage_context else None
|
||||
if (self.max_num_batched_tokens is None
|
||||
@ -1682,7 +1684,6 @@ class EngineArgs:
|
||||
"Setting max_num_batched_tokens to %d for %s usage context.",
|
||||
self.max_num_batched_tokens, use_context_value)
|
||||
|
||||
default_max_num_seqs = 1024
|
||||
if self.max_num_seqs is None:
|
||||
self.max_num_seqs = default_max_num_seqs
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user