# SPDX-License-Identifier: Apache-2.0

import torch

from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
from vllm.v1.core.specialized_manager import SlidingWindowManager
from vllm.v1.kv_cache_interface import SlidingWindowSpec


def test_sliding_window_possible_cached_prefix():
    """Check ``find_longest_cache_hit`` against several cache-hit patterns.

    The manager is built with ``block_size=2`` and ``sliding_window=4``.
    For each pattern the test verifies the length of the returned block
    list, that every entry except the last two is the pool's null block,
    and that the last (up to) two entries are the mocked cached blocks.
    """
    sliding_window_spec = SlidingWindowSpec(
        block_size=2,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=4,
        use_mla=False,
    )

    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
    manager = SlidingWindowManager(sliding_window_spec, block_pool)

    def run_one_case(block_is_cached, expect_length):
        """Mock the cache per ``block_is_cached`` and check the hit result.

        Args:
            block_is_cached: One bool per block hash; True registers a
                cached block for that hash in the pool.
            expect_length: Expected length of the list returned by
                ``find_longest_cache_hit``.
        """
        block_hash_list = [
            BlockHashType(i, ()) for i in range(len(block_is_cached))
        ]

        block_pool.cached_block_hash_to_block.clear()

        # Mock the block pool with the cached blocks. The i-th hash maps to
        # pool block i + 10 so that the expected block id is predictable.
        for i, (block_hash, is_cached) in enumerate(
                zip(block_hash_list, block_is_cached)):
            if is_cached:
                block_pool.cached_block_hash_to_block[block_hash] = {
                    i: block_pool.blocks[i + 10]
                }

        computed_blocks = manager.find_longest_cache_hit(block_hash_list)
        assert len(computed_blocks) == expect_length

        # Clamp at 0 so that short results (expect_length < 2) do not rely
        # on a negative slice bound happening to produce an empty slice.
        num_null_blocks = max(expect_length - 2, 0)
        assert all(block == block_pool.null_block
                   for block in computed_blocks[:num_null_blocks])

        # The remaining (at most two) trailing entries must be the mocked
        # cached blocks.
        for block_index in range(num_null_blocks, expect_length):
            assert computed_blocks[block_index].block_id == block_index + 10

    run_one_case([False] * 10, 0)
    run_one_case([True], 1)
    run_one_case([True, False], 1)
    run_one_case([True, True], 2)
    run_one_case([True, True, False], 2)
    run_one_case([True, True, True], 3)
    run_one_case([True, True, True, False], 3)
    run_one_case([
        True, True, False, True, False, False, True, True, False, True, True,
        True
    ], 12)
    run_one_case([
        True, True, False, True, False, False, True, True, False, False, False
    ], 8)
    run_one_case([
        True, True, False, True, False, False, True, True, False, False, False,
        True
    ], 8)


def test_sliding_window_remove_skipped_blocks():
    """Check ``remove_skipped_blocks`` as the computed-token count grows.

    The manager is built with ``block_size=2`` and ``sliding_window=4``.
    After each call the test verifies both the returned list of removed
    blocks and the block table, whose removed entries are expected to have
    been replaced in place by the pool's null block.
    """
    sliding_window_spec = SlidingWindowSpec(
        block_size=2,
        num_kv_heads=1,
        head_size=1,
        dtype=torch.float32,
        sliding_window=4,
        use_mla=False,
    )

    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)

    manager = SlidingWindowManager(sliding_window_spec, block_pool)

    null_block_id = block_pool.null_block.block_id

    def id_to_block_table(ids):
        """Build a block table from ids, using the pool's null block for
        ``null_block_id`` and a fresh ``KVCacheBlock`` otherwise."""
        return [
            KVCacheBlock(id_)
            if id_ != null_block_id else block_pool.null_block for id_ in ids
        ]

    def assert_block_id(block_table, ids):
        """Assert that ``block_table`` matches ``ids`` entry by entry."""
        # zip() stops at the shorter input, so without this length check a
        # call such as assert_block_id(removed, []) would verify nothing.
        assert len(block_table) == len(ids)
        for block, id_ in zip(block_table, ids):
            if id_ == null_block_id:
                assert block == block_pool.null_block
            else:
                assert block.block_id == id_

    original_block_ids = [
        1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
    ]
    block_table = id_to_block_table(original_block_ids)
    removed = manager.remove_skipped_blocks(block_table, 0)
    assert_block_id(removed, [])
    assert_block_id(block_table, original_block_ids)

    # 4 tokens are computed. Only token 0 is out of the sliding window. As
    # block 1000 also contains token 1 that is in the sliding window, block 1000
    # cannot be removed.
    removed = manager.remove_skipped_blocks(block_table, 4)
    assert_block_id(removed, [])
    assert_block_id(block_table, original_block_ids)

    # 5 tokens are computed. Token 0 & 1 are out of the sliding window.
    # Block 1000 can be removed.
    removed = manager.remove_skipped_blocks(block_table, 5)
    assert_block_id(removed, [original_block_ids[0]])
    assert_block_id(block_table, [null_block_id] + original_block_ids[1:])

    # 6 tokens are computed. Token 0-2 are out of the sliding window.
    # Cannot remove new block as the block 1001 is still used by token 3.
    removed = manager.remove_skipped_blocks(block_table, 6)
    assert_block_id(removed, [])
    assert_block_id(block_table, [null_block_id] + original_block_ids[1:])

    # 7 tokens are computed. Token 0-3 are out of the sliding window.
    # Block 1001 can be removed and block 1000 is already removed.
    removed = manager.remove_skipped_blocks(block_table, 7)
    assert_block_id(removed, [original_block_ids[1]])
    assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])

    # 11 tokens are computed. Token 0-7 are out of the sliding window.
    # Block 1002 & 1003 can be removed now. Block 1003 represents a longer
    # sequence, and is expected to be evicted earlier than 1002, so the order
    # of removed blocks should be [1003, 1002].
    removed = manager.remove_skipped_blocks(block_table, 11)
    assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
    assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])