From 4716377fbc1887f27732b3816bd010a6809e41bc Mon Sep 17 00:00:00 2001
From: "rongfu.leng"
Date: Wed, 9 Apr 2025 10:12:51 +0800
Subject: [PATCH] [Feature] Estimate max-model-len use available KV cache memory (#16168)

Signed-off-by: rongfu.leng
---
 tests/v1/core/test_kv_cache_utils.py | 46 +++++++++++++++++++-
 vllm/v1/core/kv_cache_utils.py       | 65 ++++++++++++++++++++++++++--
 2 files changed, 106 insertions(+), 5 deletions(-)

diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index 51836644..d2b04c15 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -3,14 +3,16 @@
 import pytest
 import torch
 
+from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
-from vllm.utils import sha256
+from vllm.utils import GiB_bytes, sha256
 # disable yapf here as it formats differently than isort such that both fail
 # yapf: disable
 from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
                                          FreeKVCacheBlockQueue, KVCacheBlock,
                                          PrefixCachingMetrics,
+                                         estimate_max_model_len,
                                          generate_block_hash_extra_keys,
                                          hash_block_tokens,
                                          hash_request_tokens,
@@ -426,3 +428,45 @@ def test_unify_kv_cache_configs():
     ]
     with pytest.raises(AssertionError):
         unify_kv_cache_configs(diff_kv_cache_config)
+
+
+@pytest.mark.parametrize(
+    ("model_id", "max_model_len", "want_estimated_max_len"), [
+        ("Qwen/Qwen1.5-7B", 16385, 16384),
+        ("Qwen/Qwen1.5-7B", 16383, 16383),
+    ])
+def test_estimate_max_model_len(model_id, max_model_len,
+                                want_estimated_max_len):
+    # Create a VllmConfig
+    model_config = ModelConfig(
+        model_id,
+        task="generate",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        max_model_len=max_model_len,
+    )
+    scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        scheduler_config=scheduler_config,
+    )
+
+    # Create KV cache specs
+    kv_cache_spec = {}
+    for i in range(32):
+        layer_name = f"layer_{i}"
+        kv_cache_spec[layer_name] = FullAttentionSpec(
+            block_size=16,
+            num_kv_heads=32,
+            head_size=128,
+            dtype=torch.float16,
+            use_mla=False,
+        )
+    # Estimate the maximum model length; a max_model_len of 16384 needs 8 GiB
+    estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
+                                               8 * GiB_bytes)
+    assert estimated_max_len == want_estimated_max_len
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index afcf7e34..bd0e01d0 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -8,7 +8,7 @@ from typing import Any, Callable, NamedTuple, Optional
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import sha256
+from vllm.utils import GiB_bytes, sha256
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         KVCacheTensor, SlidingWindowSpec)
@@ -459,6 +459,54 @@ def hash_request_tokens(hash_function: Any, block_size: int,
     return ret
 
 
+def estimate_max_model_len(vllm_config: VllmConfig,
+                           kv_cache_spec: dict[str, KVCacheSpec],
+                           available_memory: int) -> int:
+    """
+    Estimates the maximum model length that can fit in the available memory
+    using binary search.
+
+    Args:
+        vllm_config: The global VllmConfig
+        kv_cache_spec: The kv cache spec of each attention layer in the model
+        available_memory: Memory available for KV cache in bytes.
+
+    Returns:
+        The estimated maximum model length that can fit in the available memory.
+    """
+
+    # Define a function to check if a given model length fits in memory
+    def fits_in_memory(model_len: int) -> bool:
+        # Modify the max_model_len for this calculation
+        vllm_config.model_config.max_model_len = model_len
+        # Calculate memory needed for the given model length
+        memory_needed = sum(
+            (layer_spec.max_memory_usage_bytes(vllm_config)
+             for layer_spec in kv_cache_spec.values()),
+            start=0,
+        )
+        return memory_needed <= available_memory
+
+    # Binary search for the maximum model length
+    current_max = vllm_config.model_config.max_model_len
+    left, right = 1, current_max
+
+    # If even the smallest model length doesn't fit, return 0
+    if not fits_in_memory(left):
+        return 0
+
+    # Binary search for the maximum model length that fits
+    result = 1
+    while left <= right:
+        mid = (left + right) // 2
+        if fits_in_memory(mid):
+            result = mid
+            left = mid + 1
+        else:
+            right = mid - 1
+    return result
+
+
 def check_enough_kv_cache_memory(vllm_config: VllmConfig,
                                  kv_cache_spec: dict[str, KVCacheSpec],
                                  available_memory: int):
@@ -486,12 +534,21 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
         needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
 
     if needed_memory > available_memory:
+        # Estimate the maximum model length that can fit in the available memory
+        estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
+                                                   available_memory)
+        estimated_msg = ""
+        if estimated_max_len > 0:
+            estimated_msg = (" Based on the available memory, the estimated"
+                             f" maximum model length is {estimated_max_len}.")
+
         raise ValueError(
             f"To serve at least one request with the models's max seq len "
-            f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GiB KV "
+            f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV "
             f"cache is needed, which is larger than the available KV cache "
-            f"memory ({available_memory/1024/1024/1024:.2f} GiB). Try "
-            f"increasing `gpu_memory_utilization` or decreasing "
+            f"memory ({available_memory/GiB_bytes:.2f} GiB)."
+            f"{estimated_msg}"
+            f" Try increasing `gpu_memory_utilization` or decreasing "
             f"`max_model_len` when initializing the engine.")
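A back-of-the-envelope check of the constants used in test_estimate_max_model_len:
with 32 full-attention layers, 32 KV heads of head_size 128 in float16, and
16-token blocks, each token needs 2 * 32 * 128 * 2 bytes of K/V storage per layer,
so a context of 16384 tokens comes to exactly 8 GiB of KV cache. The sketch below
is standalone Python whose names are illustrative rather than vLLM APIs, and it
assumes FullAttentionSpec.max_memory_usage_bytes charges whole block_size-token
blocks, which is consistent with the expected values in the test.

    GiB = 1 << 30
    num_layers = 32      # layers in the test's kv_cache_spec
    num_kv_heads = 32
    head_size = 128
    dtype_bytes = 2      # float16
    block_size = 16      # tokens per KV cache block

    # K and V entries per token, per layer.
    bytes_per_token_per_layer = 2 * num_kv_heads * head_size * dtype_bytes  # 16 KiB
    bytes_per_block = block_size * bytes_per_token_per_layer * num_layers   # 8 MiB

    def kv_cache_bytes(model_len: int) -> int:
        # Memory is charged in whole blocks (ceil division).
        num_blocks = -(-model_len // block_size)
        return num_blocks * bytes_per_block

    assert kv_cache_bytes(16384) == 8 * GiB  # fits the 8 GiB budget exactly
    assert kv_cache_bytes(16385) > 8 * GiB   # one more token needs another block

Hence, with 8 * GiB_bytes available, the binary search settles on 16384 when the
configured max_model_len is 16385, and is simply capped at the configured value in
the 16383 case.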