diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index e23b8718..039b5e73 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, check_answers(indices, answer, test_texts) -def prep_prompts(batch_size: int): +def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): """ Generate prompts which a bunch of assignments, then asking for the value of one of them. The prompt is just under 10k tokens; sliding window is 4k so the answer is outside sliding window, but should still be correct. + + Args: + batch_size: number of prompts to generate + ln_range: an argument to control the length of the prompt """ prompts: list[str] = [] answer: list[int] = [] @@ -145,7 +149,7 @@ def prep_prompts(batch_size: int): indices.append(idx) prompt = "```python\n# We set a number of variables, " + \ f"x{idx} will be important later\n" - ln = random.randint(800, 1100) + ln = random.randint(*ln_range) for k in range(30, ln): v = random.randint(10, 99) if k == idx: @@ -157,7 +161,10 @@ def prep_prompts(batch_size: int): return prompts, answer, indices -def check_answers(indices: list[int], answer: list[int], outputs: list[str]): +def check_answers(indices: list[int], + answer: list[int], + outputs: list[str], + accept_rate: float = 0.7): answer2 = [int(text[0:2].strip()) for text in outputs] print(list(zip(indices, zip(answer, answer2)))) numok = 0 @@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]): numok += 1 frac_ok = numok / len(answer) print(f"Num OK: {numok}/{len(answer)} {frac_ok}") - assert frac_ok > 0.7 + assert frac_ok >= accept_rate def check_window(prompts: list[str]): diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 72a1874f..80dd275a 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -4,6 +4,7 @@ from typing import Optional import pytest +import torch from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams @@ -12,6 +13,8 @@ from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_block_tokens) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) def make_request(request_id, @@ -39,13 +42,23 @@ def make_request(request_id, ) +def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: + return KVCacheConfig( + num_blocks=num_blocks, + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + + @pytest.mark.parametrize("hash_algo", ["sha256", "hash"]) def test_prefill(hash_algo): manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 11), max_model_len=8192, - sliding_window=None, enable_caching=True, caching_hash_algo=hash_algo, num_preallocate_tokens=16, @@ -67,12 +80,12 @@ def test_prefill(hash_algo): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] # Check full 
block metadata parent_block_hash = None - for block_id in (0, 1, 2): - block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16]) + for block_id in (1, 2, 3): + block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) assert manager.block_pool.blocks[block_id].block_hash == block_hash @@ -80,7 +93,7 @@ def test_prefill(hash_algo): parent_block_hash = block_hash.hash_value # Check partial/preallocated block metadata - for block_id in (3, 4): + for block_id in (4, 5): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 @@ -90,11 +103,11 @@ def test_prefill(hash_algo): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert [b.block_id for b in computed_blocks] == [0, 1, 2] + assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [5, 6] + assert [b.block_id for b in blocks] == [6, 7] for block in computed_blocks: assert block.ref_cnt == 2 @@ -107,14 +120,14 @@ def test_prefill(hash_algo): # All blocks should be available. assert manager.block_pool.free_block_queue.num_free_blocks == 10 # The order should be - # [unallocated (7, 8, 9)] - # [unique_req0 (4, 3)] - # [unique_req1 (6, 5)] - # [common (2, 1, 0)] + # [unallocated (8, 9, 10)] + # [unique_req0 (5, 4)] + # [unique_req1 (7, 6)] + # [common (3, 2, 1)] assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0] + ] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1] # Cache hit in the common prefix when the original block is already free. # Incomplete 1 block (6 tokens) @@ -122,11 +135,11 @@ def test_prefill(hash_algo): req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 - assert [b.block_id for b in computed_blocks] == [0, 1, 2] + assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [7, 8] + assert [b.block_id for b in blocks] == [8, 9] # Although we only have 5 free blocks, we have 8 blocks in # the free block queue due to lazy removal. @@ -148,7 +161,7 @@ def test_prefill(hash_algo): assert num_computed_tokens == 0 blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) # This block ID order also checks the eviction order. - assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] + assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1] assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_tail is None @@ -162,10 +175,8 @@ def test_prefill_plp(): 3. 
Schedule plp request; no hit should occur; validate blocks ''' manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 11), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -186,13 +197,13 @@ def test_prefill_plp(): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] req0_block_hashes = [b.block_hash for b in blocks] # Check full block metadata parent_block_hash = None - for block_id in (0, 1, 2): - block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16]) + for block_id in (1, 2, 3): + block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) assert manager.block_pool.blocks[block_id].block_hash == block_hash @@ -200,7 +211,7 @@ def test_prefill_plp(): parent_block_hash = block_hash.hash_value # Check partial/preallocated block metadata - for block_id in (3, 4): + for block_id in (4, 5): assert manager.block_pool.blocks[block_id].block_hash is None assert manager.block_pool.blocks[block_id].ref_cnt == 1 @@ -211,11 +222,11 @@ def test_prefill_plp(): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert [b.block_id for b in computed_blocks] == [0, 1, 2] + assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) - assert [b.block_id for b in blocks] == [5, 6] + assert [b.block_id for b in blocks] == [6, 7] for block in computed_blocks: assert block.ref_cnt == 2 @@ -228,14 +239,14 @@ def test_prefill_plp(): # All blocks should be available. assert manager.block_pool.free_block_queue.num_free_blocks == 10 # The order should be - # [unallocated (7, 8, 9)] - # [unique_req0 (4, 3)] - # [unique_req1 (6, 5)] - # [common (2, 1, 0)] + # [unallocated (8, 9, 10)] + # [unique_req0 (5, 4)] + # [unique_req1 (7, 6)] + # [common (3, 2, 1)] assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0] + ] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1] # Request #2 is a prompt-logprobs request: # NO cache hit in the common prefix; duplicates request #0 cached blocks @@ -251,7 +262,7 @@ def test_prefill_plp(): block_ids = [b.block_id for b in blocks] # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks] == req0_block_hashes - assert block_ids != [0, 1, 2, 3, 4] + assert block_ids != [1, 2, 3, 4, 5] # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. @@ -263,10 +274,8 @@ def test_prefill_plp(): def test_decode(): manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 11), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -282,7 +291,7 @@ def test_decode(): assert not computed_blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) - assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] # Append slots without allocating a new block. 
req0.num_computed_tokens = 55 @@ -316,10 +325,8 @@ def test_decode(): def test_evict(): manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 11), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -350,15 +357,15 @@ def test_evict(): assert [ b.block_id for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7] + ] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8] # Touch the first 2 blocks. req2 = make_request("2", list(range(2 * 16 + 3))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert [b.block_id for b in computed_blocks] == [0, 1] + assert [b.block_id for b in computed_blocks] == [1, 2] assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, computed_blocks) - assert [b.block_id for b in blocks] == [6, 5] + assert [b.block_id for b in blocks] == [7, 6] assert manager.block_pool.free_block_queue.num_free_blocks == 6 @@ -369,10 +376,8 @@ def test_hash_block_correct_reuse(): """ block_size = 16 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=1, + make_kv_cache_config(16, 2), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -408,10 +413,8 @@ def test_computed_blocks_not_evicted(): """ block_size = 16 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=2, + make_kv_cache_config(block_size, 3), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -424,7 +427,7 @@ def test_computed_blocks_not_evicted(): assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) assert len(blocks) == 1 - assert blocks[0].block_id == 0 + assert blocks[0].block_id == 1 # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) @@ -433,7 +436,7 @@ def test_computed_blocks_not_evicted(): assert num_computed_tokens == 0 blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) assert len(blocks) == 1 - assert blocks[0].block_id == 1 + assert blocks[0].block_id == 2 # Free the blocks. manager.free(req0) @@ -444,13 +447,13 @@ def test_computed_blocks_not_evicted(): req2 = make_request("2", list(range(num_tokens * 2))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks) == 1 - assert computed_blocks[0].block_id == 0 + assert computed_blocks[0].block_id == 1 assert num_computed_tokens == block_size blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, computed_blocks) assert len(blocks) == 1 - assert blocks[0].block_id == 1 + assert blocks[0].block_id == 2 def test_basic_prefix_caching_disabled(): @@ -459,10 +462,8 @@ def test_basic_prefix_caching_disabled(): """ block_size = 4 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=4, + make_kv_cache_config(block_size, 5), max_model_len=8192, - sliding_window=None, enable_caching=False, num_preallocate_tokens=0, ) @@ -502,10 +503,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): This tests that the preallocated blocks are correctly added. """ manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=10, + make_kv_cache_config(block_size, 11), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=num_preallocate_tokens, ) @@ -586,10 +585,8 @@ def test_mm_prefix_caching(): This tests that the multi-modal prefix caching is correct. 
""" manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 11), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -629,7 +626,7 @@ def test_mm_prefix_caching(): assert block_hashes[2].extra_keys == ("bbb", ) blocks = manager.allocate_slots(req0, 59, computed_blocks) - assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] + assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5] req0.num_computed_tokens = 59 # Append slots without allocating a new block. @@ -667,10 +664,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): """ block_size = 16 manager = KVCacheManager( - block_size=block_size, - num_gpu_blocks=10, + make_kv_cache_config(block_size, 11), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -723,10 +718,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): def test_reset_prefix_cache(): manager = KVCacheManager( - block_size=16, - num_gpu_blocks=10, + make_kv_cache_config(16, 11), max_model_len=8192, - sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -736,7 +729,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) blocks = manager.allocate_slots(req0, 55) - assert [b.block_id for b in blocks] == [0, 1, 2, 3] + assert [b.block_id for b in blocks] == [1, 2, 3, 4] unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids @@ -745,7 +738,7 @@ def test_reset_prefix_cache(): assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(computed_blocks) == 3 blocks = manager.allocate_slots(req1, 7, computed_blocks) - assert [b.block_id for b in blocks] == [4] + assert [b.block_id for b in blocks] == [5] # Failed to reset prefix cache because some blocks are not freed yet. 
assert not manager.reset_prefix_cache() diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 5b965665..73af7dad 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -2,12 +2,15 @@ from typing import Optional import pytest +import torch from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager @@ -66,12 +69,21 @@ def create_scheduler( model_config=model_config, cache_config=cache_config, ) + kv_cache_config = KVCacheConfig( + num_blocks=10000, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(16, 1, 1, torch.float32, False)) + ], + ) cache_config.num_gpu_blocks = 10000 return Scheduler( scheduler_config, model_config, cache_config, lora_config=None, + kv_cache_config=kv_cache_config, log_stats=True, structured_output_manager=StructuredOutputManager(vllm_config), ) diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py new file mode 100644 index 00000000..9b4ab5fa --- /dev/null +++ b/tests/v1/core/test_specialized_manager.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.specialized_manager import SlidingWindowManager +from vllm.v1.kv_cache_interface import SlidingWindowSpec + + +def test_sliding_window_possible_cached_prefix(): + sliding_window_spec = SlidingWindowSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=4, + use_mla=False, + ) + + block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + manager = SlidingWindowManager(sliding_window_spec, block_pool) + + def run_one_case(block_is_cached, expect_length): + block_hash_list = [ + BlockHashType(i, ()) for i in range(len(block_is_cached)) + ] + + block_pool.cached_block_hash_to_block.clear() + + # Mock the block pool with the cached blocks + for i, (block_hash, + is_cached) in enumerate(zip(block_hash_list, block_is_cached)): + if is_cached: + block_pool.cached_block_hash_to_block[block_hash] = { + i: block_pool.blocks[i + 10] + } + + computed_blocks = manager.find_longest_cache_hit(block_hash_list) + assert len(computed_blocks) == expect_length + + assert all(block == block_pool.null_block + for block in computed_blocks[:expect_length - 2]) + for i in range(2): + if i < expect_length: + block_index = expect_length - i - 1 + assert computed_blocks[ + block_index].block_id == block_index + 10 + + run_one_case([False] * 10, 0) + run_one_case([True], 1) + run_one_case([True, False], 1) + run_one_case([True, True], 2) + run_one_case([True, True, False], 2) + run_one_case([True, True, True], 3) + run_one_case([True, True, True, False], 3) + run_one_case([ + True, True, False, True, False, False, True, True, False, True, True, + True + ], 12) + run_one_case([ + True, True, False, True, False, False, True, True, False, False, False + ], 8) + run_one_case([ 
+ True, True, False, True, False, False, True, True, False, False, False, + True + ], 8) + + +def test_sliding_window_remove_skipped_blocks(): + sliding_window_spec = SlidingWindowSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=4, + use_mla=False, + ) + + block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) + + manager = SlidingWindowManager(sliding_window_spec, block_pool) + + null_block_id = block_pool.null_block.block_id + + def id_to_block_table(ids): + return [ + KVCacheBlock(id_) + if id_ != null_block_id else block_pool.null_block for id_ in ids + ] + + def assert_block_id(block_table, ids): + for block, id_ in zip(block_table, ids): + if id_ == null_block_id: + assert block == block_pool.null_block + else: + assert block.block_id == id_ + + original_block_ids = [ + 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 + ] + block_table = id_to_block_table(original_block_ids) + removed = manager.remove_skipped_blocks(block_table, 0) + assert_block_id(removed, []) + assert_block_id(block_table, original_block_ids) + + # 4 tokens are computed. Only token 0 is out of the sliding window. As + # block 1000 also contains token 1 that is in the sliding window, block 1000 + # cannot be removed. + removed = manager.remove_skipped_blocks(block_table, 4) + assert_block_id(removed, []) + assert_block_id(block_table, original_block_ids) + + # 5 tokens are computed. Token 0 & 1 are out of the sliding window. + # Block 1000 can be removed. + removed = manager.remove_skipped_blocks(block_table, 5) + assert_block_id(removed, [original_block_ids[0]]) + assert_block_id(block_table, [null_block_id] + original_block_ids[1:]) + + # 6 tokens are computed. Token 0-2 are out of the sliding window. + # Cannot remove new block as the block 1001 is still used by token 3. + removed = manager.remove_skipped_blocks(block_table, 6) + assert_block_id(removed, []) + assert_block_id(block_table, [null_block_id] + original_block_ids[1:]) + + # 7 tokens are computed. Token 0-3 are out of the sliding window. + # Block 1001 can be removed and block 1000 is already removed. + removed = manager.remove_skipped_blocks(block_table, 7) + assert_block_id(removed, [original_block_ids[1]]) + assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:]) + + # 11 tokens are computed. Token 0-7 are out of the sliding window. + # Block 1002 & 1003 can be removed now. Block 1003 represents a longer + # sequence, and is expected to be evicted earlier than 1002, so the order + # of removed blocks should be [1003, 1002]. 
+ removed = manager.remove_skipped_blocks(block_table, 11) + assert_block_id(removed, [original_block_ids[3], original_block_ids[2]]) + assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:]) diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py new file mode 100644 index 00000000..a125d3fb --- /dev/null +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import pytest + +from vllm import LLM, SamplingParams + +from ...core.block.e2e.test_correctness_sliding_window import (check_answers, + prep_prompts) + + +@dataclass +class TestConfig: + sliding_window: int + ln_range: tuple[int, int] + + +model_config = { + "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)), + "google/gemma-2-2b-it": TestConfig(4096, (400, 800)), +} + + +@pytest.mark.parametrize( + "model", + [ + "bigcode/starcoder2-3b", # sliding window only + "google/gemma-2-2b-it", # sliding window + full attention + ]) +@pytest.mark.parametrize("batch_size", [5]) +@pytest.mark.parametrize("seed", [1]) +def test_sliding_window_retrival(monkeypatch, model, batch_size, seed): + """ + The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then + asks for value of one of them (which is outside the sliding window). + If we tell it upfront which we are going to be looking for, then + it answers correctly (mostly). + """ + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + test_config = model_config[model] + + llm = LLM(model=model) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + + prompts, answer, indices = prep_prompts(batch_size, + ln_range=test_config.ln_range) + + check_length(prompts, llm, test_config.sliding_window) + + # Fresh generation + responses = llm.generate(prompts, sampling_params) + check_answers(indices, + answer, + [response.outputs[0].text for response in responses], + accept_rate=1.0) + + # Re-generate with the same prompts to test prefix caching + responses = llm.generate(prompts, sampling_params) + check_answers(indices, + answer, + [response.outputs[0].text for response in responses], + accept_rate=1.0) + + +def check_length(prompts: list[str], llm: LLM, sliding_window: int): + """ + Check if the prompt length is valid, i.e., longer than the sliding window + size and shorter than the model's max length. + + Args: + prompts: list of prompts + llm: LLM object + sliding_window: Sliding window size + """ + tokenizer = llm.get_tokenizer() + max_model_len = llm.llm_engine.model_config.max_model_len + assert any( + len(tokenizer.encode(prompt)) > sliding_window + for prompt in prompts), "Prompt is too short for test" + assert all( + len(tokenizer.encode(prompt)) <= max_model_len + for prompt in prompts), "Prompt is too long for test" diff --git a/vllm/config.py b/vllm/config.py index 84b9836e..96b6f84b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1116,8 +1116,7 @@ class CacheConfig: is_attention_free: Whether the model is attention-free. num_gpu_blocks_override: Number of GPU blocks to use. This overrides the profiled num_gpu_blocks if specified. Does nothing if None. - sliding_window: Sliding window size for the KV cache. Can not work with - prefix caching enabled. + sliding_window: Sliding window size for the KV cache. enable_prefix_caching: Whether to enable prefix caching. cpu_offload_gb: Size of the CPU offload buffer in GiB. 
""" diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 79b0c42d..43f30f71 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -27,6 +27,7 @@ class BlockPool: """ def __init__(self, num_gpu_blocks: int, enable_caching: bool): + assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 self.num_gpu_blocks = num_gpu_blocks self.enable_caching = enable_caching # All kv-cache blocks. @@ -50,6 +51,11 @@ class BlockPool: self.cached_block_hash_to_block: dict[BlockHashType, dict[ int, KVCacheBlock]] = defaultdict(dict) + # To represent a placeholder block with block_id=0. + # The ref_cnt of null_block is not maintained, needs special care to + # avoid freeing it. + self.null_block = self.free_block_queue.popleft() + def get_cached_block(self, block_hash: BlockHashType) -> Optional[KVCacheBlock]: """Get a cached block by the block hash, or None if cache miss. @@ -214,7 +220,7 @@ class BlockPool: for block in blocks: # ref_cnt=0 means this block is in the free list (i.e. eviction # candidate), so remove it. - if block.ref_cnt == 0: + if block.ref_cnt == 0 and block != self.null_block: self.free_block_queue.remove(block) block.incr_ref() @@ -228,7 +234,8 @@ class BlockPool: """ for block in ordered_blocks: block.decr_ref() - if block.ref_cnt == 0: + # null_block should not be added to the free list. + if block.ref_cnt == 0 and block != self.null_block: self.free_block_queue.append(block) def reset_prefix_cache(self) -> bool: @@ -241,10 +248,10 @@ class BlockPool: False otherwise. """ num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks()) - if num_used_blocks > 0: + if num_used_blocks != 1: # The null block is always marked as used logger.warning( "Failed to reset prefix cache because some " - "blocks (%d) are not freed yet", num_used_blocks) + "blocks (%d) are not freed yet", num_used_blocks - 1) return False # Remove all hashes so that no new blocks will hit. 
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 39390bab..c0f77152 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -9,6 +9,8 @@ from vllm.utils import cdiv, sha256 from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) +from vllm.v1.core.specialized_manager import get_specialized_manager +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -19,20 +21,22 @@ class KVCacheManager: def __init__( self, - block_size: int, - num_gpu_blocks: int, + kv_cache_config: KVCacheConfig, max_model_len: int, - sliding_window: Optional[int] = None, enable_caching: bool = True, caching_hash_algo: str = "builtin", num_preallocate_tokens: int = 64, log_stats: bool = False, ) -> None: - self.block_size = block_size - self.num_gpu_blocks = num_gpu_blocks + assert len(kv_cache_config.kv_cache_groups) == 1, ( + "KVCacheManager does not support hybrid models with more than 1 " + "kv cache group") + kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec + self.block_size = kv_cache_spec.block_size + self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len - self.max_num_blocks_per_req = cdiv(max_model_len, block_size) - self.sliding_window = sliding_window + self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) + self.enable_caching = enable_caching self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash # FIXME: make prefix cache stats conditional on log_stats @@ -48,9 +52,15 @@ class KVCacheManager: # further allocation. When it uses up all the N empty blocks, it gets # N new empty blocks. self.num_preallocate_tokens = num_preallocate_tokens - self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size) + self.num_preallocate_blocks = cdiv(num_preallocate_tokens, + self.block_size) - self.block_pool = BlockPool(num_gpu_blocks, enable_caching) + self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) + + self.specialized_manager = get_specialized_manager( + kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + ) # Mapping from request ID to blocks to track the blocks allocated # for each request, so that we can free the blocks when the request @@ -117,17 +127,25 @@ class KVCacheManager: self.prefix_cache_stats.requests += 1 if request.sampling_params.prompt_logprobs is None: - # Check for cache hits - computed_blocks = [] - for block_hash in block_hashes: - # block_hashes is a chain of block hashes. If a block hash - # is not in the cached_block_hash_to_id, the following - # block hashes are not computed yet for sure. - if cached_block := self.block_pool.get_cached_block( - block_hash): - computed_blocks.append(cached_block) - else: - break + if len(block_hashes) * self.block_size == request.num_tokens: + # When prompt length is divisible by the block size and all + # blocks are cached, we need to recompute the last token. This + # have to be achieved by re-computing an entire block because + # allocate_slots() assumes num_computed_tokens is always a + # multiple of the block size. To achieve this, remove the last + # block hash from the block_hashes for find_longest_cache_hit + # This limitation can potentially be removed in the future to + # slightly improve the performance. 
+ last_block_hash = block_hashes.pop() + else: + last_block_hash = None + + computed_blocks = ( + self.specialized_manager.find_longest_cache_hit(block_hashes)) + + if last_block_hash is not None: + # Add back the last block hash if it was removed. + block_hashes.append(last_block_hash) self.prefix_cache_stats.queries += len(block_hashes) self.prefix_cache_stats.hits += len(computed_blocks) @@ -176,13 +194,24 @@ class KVCacheManager: new_computed_blocks = new_computed_blocks or [] + req_blocks = self.req_to_blocks[request.request_id] + + # Free the blocks that are skipped during the attention computation + # (e.g., tokens outside the sliding window). + # We can do this even if we cannot schedule this request due to + # insufficient free blocks. + # Should call this function before allocating new blocks to reduce + # the number of evicted blocks. + removed_blocks = self.specialized_manager.remove_skipped_blocks( + req_blocks, request.num_computed_tokens) + self.block_pool.free_blocks(removed_blocks) + # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + len(new_computed_blocks) * self.block_size) num_required_blocks = cdiv(num_computed_tokens + num_tokens, self.block_size) - req_blocks = self.req_to_blocks[request.request_id] num_new_blocks = (num_required_blocks - len(req_blocks) - len(new_computed_blocks)) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 13a3756f..34bc9369 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -9,8 +9,9 @@ from typing import Any, Callable, NamedTuple, Optional from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import sha256 -from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec, - KVCacheSpec, KVCacheTensor) +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheSpec, + KVCacheTensor, SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -483,7 +484,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, max_model_len = vllm_config.model_config.max_model_len needed_memory = 0 for layer_spec in kv_cache_spec.values(): - needed_memory += layer_spec.bytes_for_tokens(max_model_len) + needed_memory += layer_spec.max_memory_usage_bytes(vllm_config) if needed_memory > available_memory: raise ValueError( @@ -597,6 +598,33 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, return kv_cache_config +def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): + """ + Only models with one type of KV cache are supported yet. This function tries + to convert the KV cache specs to one type if the model is a hybrid model + with multiple type of KV cache. It will convert all SlidingWindowSpec to + FullAttentionSpec if both types are present. 
+ + Args: + kv_cache_spec: The kv cache spec of each attention layer in the model + """ + + has_full_attention = any( + isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) + has_sliding_window = any( + isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values()) + if has_full_attention and has_sliding_window: + for layer_name, spec in kv_cache_spec.items(): + if isinstance(spec, SlidingWindowSpec): + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=spec.block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + dtype=spec.dtype, + use_mla=spec.use_mla, + ) + + def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: @@ -613,6 +641,7 @@ def get_kv_cache_config(vllm_config: VllmConfig, The generated KVCacheConfigs """ check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) + unify_hybrid_kv_cache_specs(kv_cache_spec) if is_kv_cache_type_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9e6c8e69..4d477567 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -19,6 +19,7 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, from vllm.v1.core.sched.utils import check_stop from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -35,6 +36,7 @@ class Scheduler(SchedulerInterface): model_config: ModelConfig, cache_config: CacheConfig, lora_config: Optional[LoRAConfig], + kv_cache_config: KVCacheConfig, structured_output_manager: StructuredOutputManager, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, include_finished_set: bool = False, @@ -43,6 +45,7 @@ class Scheduler(SchedulerInterface): self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config + self.kv_cache_config = kv_cache_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager @@ -58,15 +61,11 @@ class Scheduler(SchedulerInterface): self.scheduler_config.max_num_batched_tokens self.max_model_len = self.scheduler_config.max_model_len - num_gpu_blocks = cache_config.num_gpu_blocks - assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( - block_size=self.cache_config.block_size, - num_gpu_blocks=num_gpu_blocks, + kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, - sliding_window=self.cache_config.sliding_window, - enable_caching=self.cache_config.enable_prefix_caching, + enable_caching=cache_config.enable_prefix_caching, caching_hash_algo=self.cache_config.prefix_caching_hash_algo, log_stats=self.log_stats) self.block_size = self.cache_config.block_size @@ -300,17 +299,6 @@ class Scheduler(SchedulerInterface): # `request.num_prompt_tokens` to consider the resumed requests, # which have output tokens. num_new_tokens = request.num_tokens - num_computed_tokens - if num_new_tokens == 0: - # This happens when prompt length is divisible by the block - # size and all blocks are cached. Now we force to recompute - # the last block. 
Note that we have to re-compute an entire - # block because allocate_slots() assumes num_computed_tokens - # is always a multiple of the block size. This limitation - # can potentially be removed in the future to slightly - # improve the performance. - num_computed_tokens -= self.block_size - num_new_tokens = self.block_size - computed_blocks.pop() if (0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens): num_new_tokens = ( diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py new file mode 100644 index 00000000..7a8a9836 --- /dev/null +++ b/vllm/v1/core/specialized_manager.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod + +from vllm.utils import cdiv +from vllm.v1.core.block_pool import BlockPool +from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, + SlidingWindowSpec) + + +class SpecializedManager(ABC): + """ + An abstract base class for specialized managers that handle the kv + cache management logic of different attention layers. + """ + + def __init__( + self, + kv_cache_spec: KVCacheSpec, + block_pool: BlockPool, + ) -> None: + """ + Initializes the SpecializedManager. + Args: + kv_cache_spec: The kv_cache_spec for this manager. + block_pool: The block pool. + """ + + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_pool = block_pool + + @abstractmethod + def find_longest_cache_hit( + self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + """ + Get the longest cache hit prefix of the blocks. If no cache hit is + found, return an empty list. + + Args: + block_hashes: The block hashes of the request. + Returns: + A list of cached blocks with skipped blocks replaced by null block. + For example, sliding window manager should return a list like + [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and + sliding window 8. + """ + + raise NotImplementedError + + @abstractmethod + def remove_skipped_blocks(self, blocks: list[KVCacheBlock], + num_computed_tokens: int) -> list[KVCacheBlock]: + """ + Remove the blocks that are no longer needed from `blocks`. The removed + blocks should be replaced by null_block. Return the removed blocks in + eviction order, where the first returned block should be evicted first. + Don't free the removed blocks in this function. + + Args: + blocks: The list of blocks to be updated. + num_computed_tokens: The number of tokens that have been computed. + Returns: + The removed blocks in eviction order. + """ + raise NotImplementedError + + +class FullAttentionManager(SpecializedManager): + + def find_longest_cache_hit( + self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + computed_blocks: list[KVCacheBlock] = [] + for block_hash in block_hashes: + # block_hashes is a chain of block hashes. If a block hash is not + # in the cached_block_hash_to_id, the following block hashes are + # not computed yet for sure. + if cached_block := self.block_pool.get_cached_block(block_hash): + computed_blocks.append(cached_block) + else: + break + return computed_blocks + + def remove_skipped_blocks(self, blocks: list[KVCacheBlock], + num_computed_tokens: int) -> list[KVCacheBlock]: + # No need to remove blocks for full attention. 
+ return [] + + +class SlidingWindowManager(SpecializedManager): + + def __init__(self, kv_cache_spec: SlidingWindowSpec, + block_pool: BlockPool): + super().__init__(kv_cache_spec, block_pool) + self.sliding_window = kv_cache_spec.sliding_window + # The number of contiguous blocks needed for prefix cache hit. + # -1 since the input token itself is also included in the window + self.sliding_window_contiguous_blocks = cdiv( + (kv_cache_spec.sliding_window - 1), self.block_size) + self._null_block = block_pool.null_block + + def find_longest_cache_hit( + self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to + # optimize the time complexity from O(len(block_hashes)) to + # O(len(block_hashes) / sliding_window_contiguous_blocks + + # sliding_window_contiguous_blocks), + # which is good for low cache hit rate scenarios. + computed_blocks = [self._null_block] * len(block_hashes) + num_contiguous_blocks = 0 + + # Search from right to left and early stop when a match is found. + for i in range(len(block_hashes) - 1, -1, -1): + if cached_block := self.block_pool.get_cached_block( + block_hashes[i]): + computed_blocks[i] = cached_block + num_contiguous_blocks += 1 + if (num_contiguous_blocks + >= self.sliding_window_contiguous_blocks): + # Trim the trailing blocks. + # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] + # when sliding_window_contiguous_blocks=2. + del computed_blocks[i + num_contiguous_blocks:] + return computed_blocks + else: + num_contiguous_blocks = 0 + # The first `num_contiguous_blocks` is a cache hit even if + # `num_contiguous_blocks < sliding_window_contiguous_blocks`. + del computed_blocks[num_contiguous_blocks:] + return computed_blocks + + def remove_skipped_blocks(self, blocks: list[KVCacheBlock], + num_computed_tokens: int) -> list[KVCacheBlock]: + # Remove the blocks that are no longer be in the sliding window and + # skipped during the attention computation. + last_useful_token = num_computed_tokens - self.sliding_window + 1 + last_useful_block = last_useful_token // self.block_size + + removed_blocks: list[KVCacheBlock] = [] + for i in range(last_useful_block - 1, -1, -1): + if blocks[i] == self._null_block: + # If the block is already a null block, the blocks before it + # should also have been set to null blocks by the previous calls + # to this function. 
+ break + removed_blocks.append(blocks[i]) + blocks[i] = self._null_block + return removed_blocks + + +spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = { + FullAttentionSpec: FullAttentionManager, + SlidingWindowSpec: SlidingWindowManager, +} + + +def get_specialized_manager(kv_cache_spec: KVCacheSpec, + block_pool: BlockPool) -> SpecializedManager: + manager_class = spec_manager_map[type(kv_cache_spec)] + manager = manager_class(kv_cache_spec, block_pool) + return manager diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 68a1dc15..d915d474 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -33,6 +33,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.mm_input_cache import MMInputCacheServer from vllm.v1.executor.abstract import Executor +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -66,8 +67,9 @@ class EngineCore: self.model_executor = executor_class(vllm_config) # Setup KV Caches and update CacheConfig after profiling. - num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches( - vllm_config) + num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ + self._initialize_kv_caches(vllm_config) + vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks @@ -95,10 +97,11 @@ class EngineCore: model_config=vllm_config.model_config, cache_config=vllm_config.cache_config, lora_config=vllm_config.lora_config, + kv_cache_config=kv_cache_config, + structured_output_manager=self.structured_output_manager, include_finished_set=vllm_config.parallel_config.data_parallel_size > 1, log_stats=self.log_stats, - structured_output_manager=self.structured_output_manager, ) # Setup MM Input Mapper. @@ -117,8 +120,8 @@ class EngineCore: self.batch_queue_size) self.batch_queue = queue.Queue(self.batch_queue_size) - def _initialize_kv_caches(self, - vllm_config: VllmConfig) -> tuple[int, int]: + def _initialize_kv_caches( + self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: start = time.time() # Get all kv cache needed by the model @@ -143,13 +146,14 @@ class EngineCore: unify_kv_cache_configs(kv_cache_configs) # All workers have the same kv_cache_config except layer names, so use - # an arbitrary one to get the number of blocks. + # an arbitrary one to initialize the scheduler. 
assert all([ cfg.num_blocks == kv_cache_configs[0].num_blocks for cfg in kv_cache_configs ]) num_gpu_blocks = kv_cache_configs[0].num_blocks num_cpu_blocks = 0 + scheduler_kv_cache_config = kv_cache_configs[0] # Initialize kv cache and warmup the execution self.model_executor.initialize_from_config(kv_cache_configs) @@ -157,7 +161,7 @@ class EngineCore: elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " "warmup model) took %.2f seconds"), elapsed) - return num_gpu_blocks, num_cpu_blocks + return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 867b1b61..4fc0844c 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -4,6 +4,7 @@ from dataclasses import dataclass import torch +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.utils import cdiv, get_dtype_size @@ -43,28 +44,23 @@ class KVCacheSpec: """ raise NotImplementedError - def bytes_for_tokens(self, num_tokens: int) -> int: + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: """ - The KV cache size for `num_tokens` tokens in bytes. Returns the real - memory size after padding `num_tokens` to full blocks. + The maximum possible memory usage of this KV cache in bytes. Returns: - The KV cache size + The KV cache size in bytes """ raise NotImplementedError @dataclass -class FullAttentionSpec(KVCacheSpec): +class AttentionSpec(KVCacheSpec): num_kv_heads: int head_size: int dtype: torch.dtype use_mla: bool - @property - def type_id(self) -> str: - return f"full_attention_{self.block_size}_{self.page_size_bytes}" - @property def page_size_bytes(self) -> int: # For MLA we only store a single latent vector @@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec): return coef * self.block_size * self.num_kv_heads * self.head_size \ * get_dtype_size(self.dtype) - def bytes_for_tokens(self, num_tokens: int) -> int: - return cdiv(num_tokens, self.block_size) * self.page_size_bytes + +@dataclass +class FullAttentionSpec(AttentionSpec): + + @property + def type_id(self) -> str: + return f"full_attention_{self.block_size}_{self.page_size_bytes}" + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + return cdiv(max_model_len, self.block_size) * self.page_size_bytes + + +@dataclass +class SlidingWindowSpec(AttentionSpec): + sliding_window: int + + def __post_init__(self): + assert not self.use_mla, "MLA is not supported for sliding window" + + @property + def type_id(self) -> str: + return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + max_num_batched_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens) + + # During chunked prefill, we allocate KV cache for the last + # `self.sliding_window-1` computed tokens plus the newly scheduled + # tokens. And we won't allocate KV cache for more than `max_model_len` + # tokens. + num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens, + max_model_len) + + # +1 here because the sliding window may not start from the beginning + # of the block. 
For example, if the block size is 4 and num_token + # is 4, we need two blocks [XXCD] [EF] to store the sliding + # window [CDEF] of 6 tokens. + return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes @dataclass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 43c756b1..637367a7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -28,8 +28,9 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, check_use_alibi, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec) +from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, + KVCacheConfig, KVCacheSpec, + SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.metadata import SamplingMetadata @@ -1572,7 +1573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): # different GPUs, and `kv_cache_config.num_blocks` is set to # the min of all `num_blocks`. Verify it here. assert num_blocks >= kv_cache_config.num_blocks - if isinstance(kv_cache_spec, FullAttentionSpec): + if isinstance(kv_cache_spec, AttentionSpec): kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) @@ -1611,12 +1612,21 @@ class GPUModelRunner(LoRAModelRunnerMixin): # cross-attention assert isinstance(attn_module, Attention) if attn_module.attn_type == AttentionType.DECODER: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + sliding_window=attn_module.sliding_window, + use_mla=use_mla) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 7f7318a7..c2edbaf3 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -29,7 +29,7 @@ from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK, PallasMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheSpec) + KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput, SamplerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata @@ -353,17 +353,25 @@ class TPUModelRunner: block_size = self.vllm_config.cache_config.block_size kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in forward_ctx.items(): - # TODO: Support other attention modules, e.g., sliding window, - # cross-attention, MLA. 
assert isinstance(attn_module, Attention) if attn_module.attn_type == AttentionType.DECODER: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=attn_module.dtype, - use_mla=False, - ) + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + sliding_window=attn_module.sliding_window, + use_mla=False, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=attn_module.dtype, + use_mla=False, + ) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache.
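Note (reviewer sketch, not part of the patch): SlidingWindowSpec.max_memory_usage_bytes()
bounds the per-layer KV-cache allocation by the sliding window plus one scheduling step
rather than by the full context length. A self-contained rework of that arithmetic with
assumed example values; sliding_window_pages below is illustrative only, not a vLLM API:

    from math import ceil

    def sliding_window_pages(sliding_window: int, max_num_batched_tokens: int,
                             max_model_len: int, block_size: int) -> int:
        # Keep the last (sliding_window - 1) computed tokens plus the newly
        # scheduled tokens, capped at max_model_len; the extra +1 block covers
        # a window that does not start at a block boundary.
        num_tokens = min(sliding_window - 1 + max_num_batched_tokens,
                         max_model_len)
        return ceil(num_tokens / block_size) + 1

    # E.g. block_size=16, sliding_window=4096, max_num_batched_tokens=2048,
    # max_model_len=8192: min(4095 + 2048, 8192) = 6143 tokens.
    print(sliding_window_pages(4096, 2048, 8192, 16))  # 385 pages
    # Full attention on the same config needs ceil(max_model_len / block_size):
    print(ceil(8192 / 16))                             # 512 pages

    # Each page count is then multiplied by page_size_bytes in the real spec.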