[V1][Prefix Cache] Move the logic of num_computed_tokens into KVCacheManager (#12003)

This commit is contained in:
Chen Zhang 2025-01-15 15:55:30 +08:00 committed by GitHub
parent 3f9b7ab9f5
commit 994fc655b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 61 additions and 35 deletions

View File

@ -49,9 +49,10 @@ def test_prefill():
unique_token_ids = [3] * 7 unique_token_ids = [3] * 7
all_token_ids = common_token_ids + unique_token_ids all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids) req0 = make_request("0", all_token_ids)
computed_blocks = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(req0.kv_block_hashes) == 3 assert len(req0.kv_block_hashes) == 3
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
@ -73,9 +74,10 @@ def test_prefill():
# Incomplete 1 block (5 tokens) # Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5 unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids) req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(req1.kv_block_hashes) == 3 assert len(req1.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2] assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6] assert [b.block_id for b in blocks] == [5, 6]
@ -91,7 +93,7 @@ def test_prefill():
# All blocks should be available. # All blocks should be available.
assert manager.free_block_queue.num_free_blocks == 10 assert manager.free_block_queue.num_free_blocks == 10
# The order should be # The order should be
# [unallocated (7, 8)] # [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)] # [unique_req0 (4, 3)]
# [unique_req1 (6, 5)] # [unique_req1 (6, 5)]
# [common (2, 1, 0)] # [common (2, 1, 0)]
@ -103,9 +105,10 @@ def test_prefill():
# Incomplete 1 block (6 tokens) # Incomplete 1 block (6 tokens)
unique_token_ids = [3] * 6 unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids) req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(req2.kv_block_hashes) == 3 assert len(req2.kv_block_hashes) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2] assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8] assert [b.block_id for b in blocks] == [7, 8]
@ -123,8 +126,9 @@ def test_prefill():
# Cache miss and eviction. # Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 9)) req3 = make_request("3", [99] * (16 * 9))
computed_blocks = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order. # This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
@ -150,8 +154,9 @@ def test_decode():
# Incomplete 1 block (7 tokens) # Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7 unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids) req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4] assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
@ -197,16 +202,18 @@ def test_evict():
last_token_id = 5 * 16 + 7 last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id))) req0 = make_request("0", list(range(last_token_id)))
computed_blocks = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated
# 3 blocks. # 3 blocks.
req1 = make_request("1", list(range(last_token_id, req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16))) last_token_id + 3 * 16)))
computed_blocks = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
assert len(blocks) == 3 # 3 full blocks assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16 last_token_id += 3 * 16
@ -222,8 +229,9 @@ def test_evict():
# Touch the first 2 blocks. # Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3))) req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1] assert [b.block_id for b in computed_blocks] == [0, 1]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks) blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5] assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6 assert manager.free_block_queue.num_free_blocks == 6
@ -247,8 +255,9 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it. # Allocate 1 block and cache it.
num_tokens = block_size * 1 num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens))) req = make_request("0", list(range(num_tokens)))
computed_blocks = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens, computed_blocks) blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
assert len(blocks) == 1 assert len(blocks) == 1
@ -258,8 +267,9 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the # Allocate a new block that's not full, make sure hash info on the
# block is cleared. # block is cleared.
req = make_request("1", list(range(num_tokens - 1))) req = make_request("1", list(range(num_tokens - 1)))
computed_blocks = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
assert len(blocks) == 1 assert len(blocks) == 1
@ -284,16 +294,18 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it. # Allocate a block and cache it.
num_tokens = block_size * 1 num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens))) req0 = make_request("0", list(range(num_tokens)))
computed_blocks = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1 assert len(blocks) == 1
assert blocks[0].block_id == 0 assert blocks[0].block_id == 0
# Allocate another block. # Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
computed_blocks = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1 assert len(blocks) == 1
assert blocks[0].block_id == 1 assert blocks[0].block_id == 1
@ -305,9 +317,10 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second # Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one. # cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2))) req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1 assert len(computed_blocks) == 1
assert computed_blocks[0].block_id == 0 assert computed_blocks[0].block_id == 0
assert num_computed_tokens == block_size
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
computed_blocks) computed_blocks)
@ -331,8 +344,9 @@ def test_basic_prefix_caching_disabled():
req1 = make_request("1", list(range(10))) # 2 blocks and some more req1 = make_request("1", list(range(10))) # 2 blocks and some more
computed_blocks = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10, computed_blocks) blocks = manager.allocate_slots(req1, 10, computed_blocks)
assert len(blocks) == 3 assert len(blocks) == 3
@ -341,15 +355,17 @@ def test_basic_prefix_caching_disabled():
# No caching. # No caching.
req2 = make_request("2", list(range(16))) # shared prefix req2 = make_request("2", list(range(16))) # shared prefix
computed_blocks = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16, computed_blocks) blocks = manager.allocate_slots(req2, 16, computed_blocks)
assert len(blocks) == 4 assert len(blocks) == 4
# New requests should not have any blocks. # New requests should not have any blocks.
req3 = make_request("3", list(range(4))) req3 = make_request("3", list(range(4)))
computed_blocks = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4, computed_blocks) blocks = manager.allocate_slots(req3, 4, computed_blocks)
assert not blocks assert not blocks
@ -371,8 +387,9 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size) num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)
req = make_request("0", list(range(block_size * 30))) req = make_request("0", list(range(block_size * 30)))
computed_blocks = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
# Just ask for 1 block. # Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size, computed_blocks) blocks = manager.allocate_slots(req, block_size, computed_blocks)
req.num_computed_tokens = block_size req.num_computed_tokens = block_size
@ -469,10 +486,11 @@ def test_mm_prefix_caching():
all_token_ids, all_token_ids,
mm_positions=mm_positions, mm_positions=mm_positions,
mm_hashes=mm_hashes) mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys. # Completed block should have hashes with extra keys.
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
assert len(req0.kv_block_hashes) == 3 assert len(req0.kv_block_hashes) == 3
assert req0.kv_block_hashes[0].extra_keys == ("aaa", ) assert req0.kv_block_hashes[0].extra_keys == ("aaa", )
assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb") assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb")
@ -503,8 +521,9 @@ def test_mm_prefix_caching():
all_token_ids, all_token_ids,
mm_positions=mm_positions, mm_positions=mm_positions,
mm_hashes=mm_hashes) mm_hashes=mm_hashes)
computed_blocks = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3 assert len(computed_blocks) == 3
assert num_computed_tokens == 3 * 16
def test_prefill_not_enough_free_blocks_with_computed_blocks(): def test_prefill_not_enough_free_blocks_with_computed_blocks():
@ -527,15 +546,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | ... | # | Common-0 | Common-1 | Common-2 | ... |
common_token_ids = [i for i in range(3) for _ in range(16)] common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids) req0 = make_request("0", common_token_ids)
computed_blocks = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks) manager.allocate_slots(req0, 48, computed_blocks)
block_part0 = manager.req_to_blocks[req0.request_id] block_part0 = manager.req_to_blocks[req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2) req1 = make_request("1", common_token_ids * 2)
computed_blocks = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks == block_part0 assert computed_blocks == block_part0
assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks) manager.allocate_slots(req1, 48, computed_blocks)
block_part1 = manager.req_to_blocks[req1.request_id] block_part1 = manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
@ -547,8 +568,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... | # | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2) req2 = make_request("2", [7] * block_size * 2)
computed_blocks = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks assert not computed_blocks
assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2, computed_blocks) manager.allocate_slots(req2, block_size * 2, computed_blocks)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
@ -556,8 +578,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# In this case, the ref_cnt of the computed blocks should not be changed. # In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.free_block_queue.num_free_blocks == 5 assert manager.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3) req3 = make_request("3", common_token_ids * 3)
computed_blocks = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks == block_part1 assert computed_blocks == block_part1
assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated. # Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None assert manager.allocate_slots(req3, 48, computed_blocks) is None
# Block 0-2 are used by Req 1. # Block 0-2 are used by Req 1.

View File

@ -1,5 +1,5 @@
from collections import defaultdict from collections import defaultdict
from typing import Dict, Iterable, List, Optional from typing import Dict, Iterable, List, Optional, Tuple
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
@ -69,7 +69,8 @@ class KVCacheManager:
# is finished. # is finished.
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {} self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}
def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]: def get_computed_blocks(
self, request: Request) -> Tuple[List[KVCacheBlock], int]:
"""Get the computed (cached) blocks for the request. """Get the computed (cached) blocks for the request.
Note that the computed blocks must be full. Note that the computed blocks must be full.
@ -77,11 +78,13 @@ class KVCacheManager:
request: The request to get the computed blocks. request: The request to get the computed blocks.
Returns: Returns:
A list of blocks that are computed for the request. A tuple containing:
- A list of blocks that are computed for the request.
- The number of computed tokens.
""" """
if not self.enable_caching: if not self.enable_caching:
# Prefix caching is disabled. # Prefix caching is disabled.
return [] return [], 0
computed_blocks = [] computed_blocks = []
@ -101,7 +104,11 @@ class KVCacheManager:
else: else:
break break
return computed_blocks # NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens
def append_slots( def append_slots(
self, self,

View File

@ -184,12 +184,8 @@ class Scheduler:
request = self.waiting[0] request = self.waiting[0]
# Get already-cached tokens. # Get already-cached tokens.
computed_blocks = self.kv_cache_manager.get_computed_blocks( computed_blocks, num_computed_tokens = \
request) self.kv_cache_manager.get_computed_blocks(request)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
# Number of tokens to be scheduled. # Number of tokens to be scheduled.
# We use `request.num_tokens` instead of # We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests, # `request.num_prompt_tokens` to consider the resumed requests,