[Core][v1] Unify allocating slots in prefill and decode in KV cache manager (#12608)

As mentioned in RFC https://github.com/vllm-project/vllm/issues/12254,
this PR achieves the task: combine allocate_slots and append_slots.

There should be no functionality change, except that in decode, also
raise exception when num_tokens is zero (like prefill), and change the
unit test case accordingly.

@comaniac @rickyyx @WoosukKwon @youkaichao @heheda12345 @simon-mo

---------

Signed-off-by: Shawn Du <shawnd200@outlook.com>
This commit is contained in:
Shawn Du 2025-02-02 16:40:58 +08:00 committed by GitHub
parent abfcdcdf27
commit f8ece6e17f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 78 additions and 116 deletions

View File

@ -164,7 +164,7 @@ def test_decode():
req0.num_computed_tokens = 55 req0.num_computed_tokens = 55
for _ in range(4): for _ in range(4):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 4) new_blocks = manager.allocate_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
@ -175,7 +175,7 @@ def test_decode():
# the preallocated block. # the preallocated block.
for _ in range(5 + 10): for _ in range(5 + 10):
req0.append_output_token_ids(7) req0.append_output_token_ids(7)
new_blocks = manager.append_slots(req0, 15) new_blocks = manager.allocate_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks) == 0
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
@ -185,7 +185,7 @@ def test_decode():
# the preallocated block. # the preallocated block.
for _ in range(6 + 11): for _ in range(6 + 11):
req0.append_output_token_ids(12) req0.append_output_token_ids(12)
new_blocks = manager.append_slots(req0, 17) new_blocks = manager.allocate_slots(req0, 17)
# Plus one preallocated block. # Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2 assert new_blocks is not None and len(new_blocks) == 2
@ -395,12 +395,14 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
req.num_computed_tokens = block_size req.num_computed_tokens = block_size
assert len(blocks) == 1 + num_preallocated_blocks assert len(blocks) == 1 + num_preallocated_blocks
# Assume all computed. # Assume all computed, only when num_preallocate_tokens > 0, we need to
manager.append_slots(req, block_size * (len(blocks) - 1)) # consume the previously preallocated blocks.
if num_preallocated_blocks > 0:
manager.allocate_slots(req, block_size * (len(blocks) - 1))
req.num_computed_tokens = block_size * len(blocks) req.num_computed_tokens = block_size * len(blocks)
# Append 1 block. # Append 1 block.
blocks = manager.append_slots(req, block_size) blocks = manager.allocate_slots(req, block_size)
assert len(blocks) == 1 + num_preallocated_blocks assert len(blocks) == 1 + num_preallocated_blocks
@ -503,7 +505,7 @@ def test_mm_prefix_caching():
# Append slots without allocating a new block. # Append slots without allocating a new block.
for _ in range(5): for _ in range(5):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 5) new_blocks = manager.allocate_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks) == 0
# The just completed block should have hashes with extra keys. # The just completed block should have hashes with extra keys.
@ -603,7 +605,7 @@ def test_reset_prefix_cache():
unique_token_ids = [3] * 7 unique_token_ids = [3] * 7
all_token_ids = full_block_token_ids + unique_token_ids all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids) req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55, []) blocks = manager.allocate_slots(req0, 55)
assert [b.block_id for b in blocks] == [0, 1, 2, 3] assert [b.block_id for b in blocks] == [0, 1, 2, 3]
unique_token_ids = [4] * 7 unique_token_ids = [4] * 7
@ -639,7 +641,7 @@ def test_uncache_blocks():
) )
req0 = make_request("0", list(range(30))) req0 = make_request("0", list(range(30)))
blocks = manager.allocate_slots(req0, 30, []) blocks = manager.allocate_slots(req0, 30)
assert [b.block_id for b in blocks] == [0, 1] assert [b.block_id for b in blocks] == [0, 1]
assert len(manager.cached_block_hash_to_block) == 1 assert len(manager.cached_block_hash_to_block) == 1
@ -648,7 +650,7 @@ def test_uncache_blocks():
# Simulate speculative tokens. # Simulate speculative tokens.
for _ in range(5): for _ in range(5):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
manager.append_slots(req0, 5) manager.allocate_slots(req0, 5)
assert len(manager.cached_block_hash_to_block) == 2 assert len(manager.cached_block_hash_to_block) == 2
# After sampling, assuming only 1 token is accepted. # After sampling, assuming only 1 token is accepted.

View File

@ -1,5 +1,5 @@
from collections import defaultdict from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Tuple from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv from vllm.utils import cdiv
@ -67,7 +67,8 @@ class KVCacheManager:
# Mapping from request ID to blocks to track the blocks allocated # Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request # for each request, so that we can free the blocks when the request
# is finished. # is finished.
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {} self.req_to_blocks: DefaultDict[str,
List[KVCacheBlock]] = defaultdict(list)
@property @property
def usage(self) -> float: def usage(self) -> float:
@ -115,33 +116,75 @@ class KVCacheManager:
num_computed_tokens = len(computed_blocks) * self.block_size num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens return computed_blocks, num_computed_tokens
def append_slots( def allocate_slots(
self, self,
request: Request, request: Request,
num_tokens: int, num_tokens: int,
new_computed_blocks: Optional[List[KVCacheBlock]] = None
) -> Optional[List[KVCacheBlock]]: ) -> Optional[List[KVCacheBlock]]:
"""Append slots to the block table of the request. """Add slots for a request with new tokens to append.
We first append slots to already allocated blocks. If the allocated
blocks are not enough, we allocate new blocks.
Args: Args:
request: The request to append slots. request: The request to allocate slots.
num_tokens: The number of tokens to append. num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed.
new_computed_blocks: A list of new computed blocks just hitting the
prefix caching.
Blocks layout:
-----------------------------------------------------------------------
| < computed > | < new computed > | < new > | < pre-allocated > |
-----------------------------------------------------------------------
| < required > |
--------------------------------------------------
| < full > |
------------------------------------------------
| <new full> |
--------------
The following *_blocks are illustrated in this layout.
Returns: Returns:
A list of new blocks if new blocks are allocated, or None A list of new allocated blocks.
if new blocks are required but cannot be allocated.
""" """
num_required_blocks = cdiv(request.num_computed_tokens + num_tokens, if num_tokens == 0:
raise ValueError("num_tokens must be greater than 0")
new_computed_blocks = new_computed_blocks or []
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
self.block_size) self.block_size)
req_blocks = self.req_to_blocks[request.request_id] req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))
num_new_blocks = num_required_blocks - len(req_blocks) # If a computed block of a request is an eviction candidate (in the
if num_new_blocks > self.free_block_queue.num_free_blocks: # free queue and ref_cnt == 0), it cannot be counted as a free block
# Need to allocate new blocks due to insufficient pre-allocated # when allocating this request.
# slots, but we cannot allocate new blocks due to the limit. num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
if blk.ref_cnt == 0)
if (num_new_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks):
# Cannot allocate new blocks
return None return None
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self._touch(new_computed_blocks)
else:
assert not new_computed_blocks, (
"Computed blocks should be empty when "
"prefix caching is disabled")
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
req_blocks.extend(new_computed_blocks)
# Start to handle new blocks
if num_new_blocks <= 0: if num_new_blocks <= 0:
# No new block is needed. # No new block is needed.
new_blocks = [] new_blocks = []
@ -160,112 +203,29 @@ class KVCacheManager:
) )
assert num_new_blocks > 0 assert num_new_blocks > 0
# Concatenate the computed block IDs and the new block IDs.
new_blocks = self._get_new_blocks(num_new_blocks) new_blocks = self._get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks) req_blocks.extend(new_blocks)
if not self.enable_caching: if not self.enable_caching:
return new_blocks return new_blocks
num_computed_full_blocks = (request.num_computed_tokens //
self.block_size)
# NOTE(rickyx): We are assuming the `num_tokens` are actual # NOTE(rickyx): We are assuming the `num_tokens` are actual
# tokens rather than lookahead slots (e.g. for speculative decoding). # tokens rather than lookahead slots (e.g. for speculative decoding).
# TODO(rickyx): When supporting speculative decoding, we will need to # TODO(rickyx): When supporting speculative decoding, we will need to
# differentiate between them so that we can know how many blocks are # differentiate between them so that we can know how many blocks are
# full after appending the actual tokens. # full after appending the actual tokens.
num_full_blocks_after_append = (request.num_computed_tokens + num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
num_tokens) // self.block_size num_computed_full_blocks = num_computed_tokens // self.block_size
assert num_full_blocks_after_append <= len(req_blocks) new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
new_full_blocks = req_blocks[
num_computed_full_blocks:num_full_blocks_after_append]
if new_full_blocks: if new_full_blocks:
self._cache_full_blocks( self._cache_full_blocks(
request=request, request=request,
blk_start_idx=num_computed_full_blocks, blk_start_idx=num_computed_full_blocks,
full_blocks=new_full_blocks,
prev_block=req_blocks[num_computed_full_blocks - 1]
if num_computed_full_blocks >= 1 else None,
)
return new_blocks
def allocate_slots(
self,
request: Request,
num_tokens: int,
computed_blocks: List[KVCacheBlock],
) -> Optional[List[KVCacheBlock]]:
"""Allocate slots for a new request.
Args:
request: The request to allocate slots.
num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed.
computed_blocks: A list of computed blocks.
Returns:
A list of new allocated blocks.
"""
if num_tokens == 0:
raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}")
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = sum(1 for blk in computed_blocks
if blk.ref_cnt == 0)
num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks):
# Cannot allocate new blocks.
return None
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self._touch(computed_blocks)
else:
assert not computed_blocks, (
"Computed blocks should be empty when "
"prefix caching is disabled")
# Determine the number of new blocks to allocate considering
# preallocated blocks.
num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks,
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
# TODO(woosuk): Check and reject requests if
# num_prompt_tokens + max_tokens > max_model_len.
self.max_num_blocks_per_req - len(computed_blocks),
)
assert num_new_blocks > 0
# Concatenate the computed block IDs and the new block IDs.
new_blocks = self._get_new_blocks(num_new_blocks)
self.req_to_blocks[request.request_id] = computed_blocks + new_blocks
if not self.enable_caching:
return new_blocks
num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
new_full_blocks = self.req_to_blocks[
request.request_id][len(computed_blocks):num_full_blocks]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed. # The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks, full_blocks=new_full_blocks,
prev_block=computed_blocks[-1] if computed_blocks else None, prev_block=(req_blocks[num_computed_full_blocks - 1]
) if num_computed_full_blocks > 0 else None))
return new_blocks return new_blocks

View File

@ -138,7 +138,7 @@ class Scheduler:
assert num_new_tokens > 0 assert num_new_tokens > 0
while True: while True:
new_blocks = self.kv_cache_manager.append_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens) request, num_new_tokens)
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. # The request cannot be scheduled.