[Feature] Estimate max-model-len use available KV cache memory (#16168)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
rongfu.leng 2025-04-09 10:12:51 +08:00 committed by GitHub
parent 4e9cf8c1dd
commit 4716377fbc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 106 additions and 5 deletions

View File

@ -3,14 +3,16 @@
import pytest
import torch
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.utils import GiB_bytes, sha256
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
FreeKVCacheBlockQueue, KVCacheBlock,
PrefixCachingMetrics,
estimate_max_model_len,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens,
@ -426,3 +428,45 @@ def test_unify_kv_cache_configs():
]
with pytest.raises(AssertionError):
unify_kv_cache_configs(diff_kv_cache_config)
@pytest.mark.parametrize(
("model_id", "max_model_len", "want_estimated_max_len"), [
("Qwen/Qwen1.5-7B", 16385, 16384),
("Qwen/Qwen1.5-7B", 16383, 16383),
])
def test_estimate_max_model_len(model_id, max_model_len,
want_estimated_max_len):
# Create a VllmConfig
model_config = ModelConfig(
model_id,
task="generate",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
max_model_len=max_model_len,
)
scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)
vllm_config = VllmConfig(
model_config=model_config,
scheduler_config=scheduler_config,
)
# Create KV cache specs
kv_cache_spec = {}
for i in range(32):
layer_name = f"layer_{i}"
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=16,
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
use_mla=False,
)
# Estimate the maximum model length, 16384 model_len need 8GB
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
8 * GiB_bytes)
assert estimated_max_len == want_estimated_max_len

View File

@ -8,7 +8,7 @@ from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.utils import GiB_bytes, sha256
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
@ -459,6 +459,54 @@ def hash_request_tokens(hash_function: Any, block_size: int,
return ret
def estimate_max_model_len(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> int:
"""
Estimates the maximum model length that can fit in the available memory
using binary search.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The estimated maximum model length that can fit in the available memory.
"""
# Define a function to check if a given model length fits in memory
def fits_in_memory(model_len: int) -> bool:
# Modify the max_model_len for this calculation
vllm_config.model_config.max_model_len = model_len
# Calculate memory needed for the given model length
memory_needed = sum(
(layer_spec.max_memory_usage_bytes(vllm_config)
for layer_spec in kv_cache_spec.values()),
start=0,
)
return memory_needed <= available_memory
# Binary search for the maximum model length
current_max = vllm_config.model_config.max_model_len
left, right = 1, current_max
# If even the smallest model length doesn't fit, return 0
if not fits_in_memory(left):
return 0
# Binary search for the maximum model length that fits
result = 1
while left <= right:
mid = (left + right) // 2
if fits_in_memory(mid):
result = mid
left = mid + 1
else:
right = mid - 1
return result
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int):
@ -486,12 +534,21 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
if needed_memory > available_memory:
# Estimate the maximum model length that can fit in the available memory
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
available_memory)
estimated_msg = ""
if estimated_max_len > 0:
estimated_msg = " Based on the available memory,"
f" the estimated maximum model length is {estimated_max_len}."
raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GiB KV "
f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory/1024/1024/1024:.2f} GiB). Try "
f"increasing `gpu_memory_utilization` or decreasing "
f"memory ({available_memory/GiB_bytes:.2f} GiB)."
f"{estimated_msg} "
f" Try increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine.")