[Feature] Estimate max-model-len use available KV cache memory (#16168)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
This commit is contained in:
parent
4e9cf8c1dd
commit
4716377fbc
@ -3,14 +3,16 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
from vllm.utils import GiB_bytes, sha256
|
||||
# disable yapf here as it formats differently than isort such that both fail
|
||||
# yapf: disable
|
||||
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock,
|
||||
PrefixCachingMetrics,
|
||||
estimate_max_model_len,
|
||||
generate_block_hash_extra_keys,
|
||||
hash_block_tokens,
|
||||
hash_request_tokens,
|
||||
@ -426,3 +428,45 @@ def test_unify_kv_cache_configs():
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
unify_kv_cache_configs(diff_kv_cache_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "max_model_len", "want_estimated_max_len"), [
|
||||
("Qwen/Qwen1.5-7B", 16385, 16384),
|
||||
("Qwen/Qwen1.5-7B", 16383, 16383),
|
||||
])
|
||||
def test_estimate_max_model_len(model_id, max_model_len,
|
||||
want_estimated_max_len):
|
||||
# Create a VllmConfig
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task="generate",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
max_model_len=max_model_len,
|
||||
)
|
||||
scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
scheduler_config=scheduler_config,
|
||||
)
|
||||
|
||||
# Create KV cache specs
|
||||
kv_cache_spec = {}
|
||||
for i in range(32):
|
||||
layer_name = f"layer_{i}"
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=16,
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
)
|
||||
# Estimate the maximum model length, 16384 model_len need 8GB
|
||||
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
|
||||
8 * GiB_bytes)
|
||||
assert estimated_max_len == want_estimated_max_len
|
||||
|
@ -8,7 +8,7 @@ from typing import Any, Callable, NamedTuple, Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import sha256
|
||||
from vllm.utils import GiB_bytes, sha256
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
KVCacheTensor, SlidingWindowSpec)
|
||||
@ -459,6 +459,54 @@ def hash_request_tokens(hash_function: Any, block_size: int,
|
||||
return ret
|
||||
|
||||
|
||||
def estimate_max_model_len(vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> int:
|
||||
"""
|
||||
Estimates the maximum model length that can fit in the available memory
|
||||
using binary search.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
|
||||
Returns:
|
||||
The estimated maximum model length that can fit in the available memory.
|
||||
"""
|
||||
|
||||
# Define a function to check if a given model length fits in memory
|
||||
def fits_in_memory(model_len: int) -> bool:
|
||||
# Modify the max_model_len for this calculation
|
||||
vllm_config.model_config.max_model_len = model_len
|
||||
# Calculate memory needed for the given model length
|
||||
memory_needed = sum(
|
||||
(layer_spec.max_memory_usage_bytes(vllm_config)
|
||||
for layer_spec in kv_cache_spec.values()),
|
||||
start=0,
|
||||
)
|
||||
return memory_needed <= available_memory
|
||||
|
||||
# Binary search for the maximum model length
|
||||
current_max = vllm_config.model_config.max_model_len
|
||||
left, right = 1, current_max
|
||||
|
||||
# If even the smallest model length doesn't fit, return 0
|
||||
if not fits_in_memory(left):
|
||||
return 0
|
||||
|
||||
# Binary search for the maximum model length that fits
|
||||
result = 1
|
||||
while left <= right:
|
||||
mid = (left + right) // 2
|
||||
if fits_in_memory(mid):
|
||||
result = mid
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid - 1
|
||||
return result
|
||||
|
||||
|
||||
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int):
|
||||
@ -486,12 +534,21 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
|
||||
needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
|
||||
|
||||
if needed_memory > available_memory:
|
||||
# Estimate the maximum model length that can fit in the available memory
|
||||
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
estimated_msg = ""
|
||||
if estimated_max_len > 0:
|
||||
estimated_msg = " Based on the available memory,"
|
||||
f" the estimated maximum model length is {estimated_max_len}."
|
||||
|
||||
raise ValueError(
|
||||
f"To serve at least one request with the models's max seq len "
|
||||
f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GiB KV "
|
||||
f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV "
|
||||
f"cache is needed, which is larger than the available KV cache "
|
||||
f"memory ({available_memory/1024/1024/1024:.2f} GiB). Try "
|
||||
f"increasing `gpu_memory_utilization` or decreasing "
|
||||
f"memory ({available_memory/GiB_bytes:.2f} GiB)."
|
||||
f"{estimated_msg} "
|
||||
f" Try increasing `gpu_memory_utilization` or decreasing "
|
||||
f"`max_model_len` when initializing the engine.")
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user