[V1][Prefix Cache] Move the logic of num_computed_tokens into KVCacheManager (#12003)
This commit is contained in:
parent
3f9b7ab9f5
commit
994fc655b7
@ -49,9 +49,10 @@ def test_prefill():
|
||||
unique_token_ids = [3] * 7
|
||||
all_token_ids = common_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids)
|
||||
computed_blocks = manager.get_computed_blocks(req0)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert len(req0.kv_block_hashes) == 3
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
|
||||
|
||||
@ -73,9 +74,10 @@ def test_prefill():
|
||||
# Incomplete 1 block (5 tokens)
|
||||
unique_token_ids = [3] * 5
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
computed_blocks = manager.get_computed_blocks(req1)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(req1.kv_block_hashes) == 3
|
||||
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [5, 6]
|
||||
@ -91,7 +93,7 @@ def test_prefill():
|
||||
# All blocks should be available.
|
||||
assert manager.free_block_queue.num_free_blocks == 10
|
||||
# The order should be
|
||||
# [unallocated (7, 8)]
|
||||
# [unallocated (7, 8, 9)]
|
||||
# [unique_req0 (4, 3)]
|
||||
# [unique_req1 (6, 5)]
|
||||
# [common (2, 1, 0)]
|
||||
@ -103,9 +105,10 @@ def test_prefill():
|
||||
# Incomplete 1 block (6 tokens)
|
||||
unique_token_ids = [3] * 6
|
||||
req2 = make_request("2", common_token_ids + unique_token_ids)
|
||||
computed_blocks = manager.get_computed_blocks(req2)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(req2.kv_block_hashes) == 3
|
||||
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [7, 8]
|
||||
@ -123,8 +126,9 @@ def test_prefill():
|
||||
|
||||
# Cache miss and eviction.
|
||||
req3 = make_request("3", [99] * (16 * 9))
|
||||
computed_blocks = manager.get_computed_blocks(req3)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
|
||||
# This block ID order also checks the eviction order.
|
||||
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
|
||||
@ -150,8 +154,9 @@ def test_decode():
|
||||
# Incomplete 1 block (7 tokens)
|
||||
unique_token_ids = [3] * 7
|
||||
req0 = make_request("0", common_token_ids + unique_token_ids)
|
||||
computed_blocks = manager.get_computed_blocks(req0)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
|
||||
|
||||
@ -197,16 +202,18 @@ def test_evict():
|
||||
|
||||
last_token_id = 5 * 16 + 7
|
||||
req0 = make_request("0", list(range(last_token_id)))
|
||||
computed_blocks = manager.get_computed_blocks(req0)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
|
||||
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated
|
||||
|
||||
# 3 blocks.
|
||||
req1 = make_request("1", list(range(last_token_id,
|
||||
last_token_id + 3 * 16)))
|
||||
computed_blocks = manager.get_computed_blocks(req1)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
|
||||
assert len(blocks) == 3 # 3 full blocks
|
||||
last_token_id += 3 * 16
|
||||
@ -222,8 +229,9 @@ def test_evict():
|
||||
|
||||
# Touch the first 2 blocks.
|
||||
req2 = make_request("2", list(range(2 * 16 + 3)))
|
||||
computed_blocks = manager.get_computed_blocks(req2)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert [b.block_id for b in computed_blocks] == [0, 1]
|
||||
assert num_computed_tokens == 2 * 16
|
||||
blocks = manager.allocate_slots(req2, 3, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [6, 5]
|
||||
assert manager.free_block_queue.num_free_blocks == 6
|
||||
@ -247,8 +255,9 @@ def test_hash_block_correct_reuse():
|
||||
# Allocate 1 block and cache it.
|
||||
num_tokens = block_size * 1
|
||||
req = make_request("0", list(range(num_tokens)))
|
||||
computed_blocks = manager.get_computed_blocks(req)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
|
||||
assert len(blocks) == 1
|
||||
|
||||
@ -258,8 +267,9 @@ def test_hash_block_correct_reuse():
|
||||
# Allocate a new block that's not full, make sure hash info on the
|
||||
# block is cleared.
|
||||
req = make_request("1", list(range(num_tokens - 1)))
|
||||
computed_blocks = manager.get_computed_blocks(req)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
|
||||
assert len(blocks) == 1
|
||||
|
||||
@ -284,16 +294,18 @@ def test_computed_blocks_not_evicted():
|
||||
# Allocate a block and cache it.
|
||||
num_tokens = block_size * 1
|
||||
req0 = make_request("0", list(range(num_tokens)))
|
||||
computed_blocks = manager.get_computed_blocks(req0)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0].block_id == 0
|
||||
|
||||
# Allocate another block.
|
||||
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
|
||||
computed_blocks = manager.get_computed_blocks(req1)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0].block_id == 1
|
||||
@ -305,9 +317,10 @@ def test_computed_blocks_not_evicted():
|
||||
# Now if we have a cache hit on the first block, we should evict the second
|
||||
# cached block rather than the first one.
|
||||
req2 = make_request("2", list(range(num_tokens * 2)))
|
||||
computed_blocks = manager.get_computed_blocks(req2)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(computed_blocks) == 1
|
||||
assert computed_blocks[0].block_id == 0
|
||||
assert num_computed_tokens == block_size
|
||||
|
||||
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
|
||||
computed_blocks)
|
||||
@ -331,8 +344,9 @@ def test_basic_prefix_caching_disabled():
|
||||
|
||||
req1 = make_request("1", list(range(10))) # 2 blocks and some more
|
||||
|
||||
computed_blocks = manager.get_computed_blocks(req1)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req1, 10, computed_blocks)
|
||||
assert len(blocks) == 3
|
||||
|
||||
@ -341,15 +355,17 @@ def test_basic_prefix_caching_disabled():
|
||||
|
||||
# No caching.
|
||||
req2 = make_request("2", list(range(16))) # shared prefix
|
||||
computed_blocks = manager.get_computed_blocks(req2)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req2, 16, computed_blocks)
|
||||
assert len(blocks) == 4
|
||||
|
||||
# New requests should not have any blocks.
|
||||
req3 = make_request("3", list(range(4)))
|
||||
computed_blocks = manager.get_computed_blocks(req3)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req3, 4, computed_blocks)
|
||||
assert not blocks
|
||||
|
||||
@ -371,8 +387,9 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
|
||||
num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)
|
||||
|
||||
req = make_request("0", list(range(block_size * 30)))
|
||||
computed_blocks = manager.get_computed_blocks(req)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
# Just ask for 1 block.
|
||||
blocks = manager.allocate_slots(req, block_size, computed_blocks)
|
||||
req.num_computed_tokens = block_size
|
||||
@ -469,10 +486,11 @@ def test_mm_prefix_caching():
|
||||
all_token_ids,
|
||||
mm_positions=mm_positions,
|
||||
mm_hashes=mm_hashes)
|
||||
computed_blocks = manager.get_computed_blocks(req0)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
|
||||
# Completed block should have hashes with extra keys.
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
assert len(req0.kv_block_hashes) == 3
|
||||
assert req0.kv_block_hashes[0].extra_keys == ("aaa", )
|
||||
assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb")
|
||||
@ -503,8 +521,9 @@ def test_mm_prefix_caching():
|
||||
all_token_ids,
|
||||
mm_positions=mm_positions,
|
||||
mm_hashes=mm_hashes)
|
||||
computed_blocks = manager.get_computed_blocks(req1)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(computed_blocks) == 3
|
||||
assert num_computed_tokens == 3 * 16
|
||||
|
||||
|
||||
def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
@ -527,15 +546,17 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
# | Common-0 | Common-1 | Common-2 | ... |
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
req0 = make_request("0", common_token_ids)
|
||||
computed_blocks = manager.get_computed_blocks(req0)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
manager.allocate_slots(req0, 48, computed_blocks)
|
||||
block_part0 = manager.req_to_blocks[req0.request_id]
|
||||
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
|
||||
req1 = make_request("1", common_token_ids * 2)
|
||||
computed_blocks = manager.get_computed_blocks(req1)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert computed_blocks == block_part0
|
||||
assert num_computed_tokens == 3 * 16
|
||||
manager.allocate_slots(req1, 48, computed_blocks)
|
||||
block_part1 = manager.req_to_blocks[req1.request_id]
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
|
||||
@ -547,8 +568,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
|
||||
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
|
||||
req2 = make_request("2", [7] * block_size * 2)
|
||||
computed_blocks = manager.get_computed_blocks(req2)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
manager.allocate_slots(req2, block_size * 2, computed_blocks)
|
||||
|
||||
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
|
||||
@ -556,8 +578,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
# In this case, the ref_cnt of the computed blocks should not be changed.
|
||||
assert manager.free_block_queue.num_free_blocks == 5
|
||||
req3 = make_request("3", common_token_ids * 3)
|
||||
computed_blocks = manager.get_computed_blocks(req3)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert computed_blocks == block_part1
|
||||
assert num_computed_tokens == 6 * 16
|
||||
# Req3 cannot be allocated.
|
||||
assert manager.allocate_slots(req3, 48, computed_blocks) is None
|
||||
# Block 0-2 are used by Req 1.
|
||||
|
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv
|
||||
@ -69,7 +69,8 @@ class KVCacheManager:
|
||||
# is finished.
|
||||
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}
|
||||
|
||||
def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
|
||||
def get_computed_blocks(
|
||||
self, request: Request) -> Tuple[List[KVCacheBlock], int]:
|
||||
"""Get the computed (cached) blocks for the request.
|
||||
Note that the computed blocks must be full.
|
||||
|
||||
@ -77,11 +78,13 @@ class KVCacheManager:
|
||||
request: The request to get the computed blocks.
|
||||
|
||||
Returns:
|
||||
A list of blocks that are computed for the request.
|
||||
A tuple containing:
|
||||
- A list of blocks that are computed for the request.
|
||||
- The number of computed tokens.
|
||||
"""
|
||||
if not self.enable_caching:
|
||||
# Prefix caching is disabled.
|
||||
return []
|
||||
return [], 0
|
||||
|
||||
computed_blocks = []
|
||||
|
||||
@ -101,7 +104,11 @@ class KVCacheManager:
|
||||
else:
|
||||
break
|
||||
|
||||
return computed_blocks
|
||||
# NOTE(woosuk): Since incomplete blocks are not eligible for
|
||||
# sharing, `num_computed_tokens` is always a multiple of
|
||||
# `block_size`.
|
||||
num_computed_tokens = len(computed_blocks) * self.block_size
|
||||
return computed_blocks, num_computed_tokens
|
||||
|
||||
def append_slots(
|
||||
self,
|
||||
|
@ -184,12 +184,8 @@ class Scheduler:
|
||||
|
||||
request = self.waiting[0]
|
||||
# Get already-cached tokens.
|
||||
computed_blocks = self.kv_cache_manager.get_computed_blocks(
|
||||
request)
|
||||
# NOTE(woosuk): Since incomplete blocks are not eligible for
|
||||
# sharing, `num_computed_tokens` is always a multiple of
|
||||
# `block_size`.
|
||||
num_computed_tokens = len(computed_blocks) * self.block_size
|
||||
computed_blocks, num_computed_tokens = \
|
||||
self.kv_cache_manager.get_computed_blocks(request)
|
||||
# Number of tokens to be scheduled.
|
||||
# We use `request.num_tokens` instead of
|
||||
# `request.num_prompt_tokens` to consider the resumed requests,
|
||||
|
Loading…
x
Reference in New Issue
Block a user