[Core][v1] Unify allocating slots in prefill and decode in KV cache manager (#12608)
As mentioned in RFC https://github.com/vllm-project/vllm/issues/12254, this PR combines allocate_slots and append_slots into a single method. There should be no functional change, except that decode now also raises an exception when num_tokens is zero (as prefill already does); the unit test case is updated accordingly. @comaniac @rickyyx @WoosukKwon @youkaichao @heheda12345 @simon-mo --------- Signed-off-by: Shawn Du <shawnd200@outlook.com>
This commit is contained in:
parent
abfcdcdf27
commit
f8ece6e17f
@ -164,7 +164,7 @@ def test_decode():
|
|||||||
req0.num_computed_tokens = 55
|
req0.num_computed_tokens = 55
|
||||||
for _ in range(4):
|
for _ in range(4):
|
||||||
req0.append_output_token_ids(8)
|
req0.append_output_token_ids(8)
|
||||||
new_blocks = manager.append_slots(req0, 4)
|
new_blocks = manager.allocate_slots(req0, 4)
|
||||||
assert new_blocks is not None and len(new_blocks) == 0
|
assert new_blocks is not None and len(new_blocks) == 0
|
||||||
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
|
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
|
||||||
|
|
||||||
@ -175,7 +175,7 @@ def test_decode():
|
|||||||
# the preallocated block.
|
# the preallocated block.
|
||||||
for _ in range(5 + 10):
|
for _ in range(5 + 10):
|
||||||
req0.append_output_token_ids(7)
|
req0.append_output_token_ids(7)
|
||||||
new_blocks = manager.append_slots(req0, 15)
|
new_blocks = manager.allocate_slots(req0, 15)
|
||||||
assert new_blocks is not None and len(new_blocks) == 0
|
assert new_blocks is not None and len(new_blocks) == 0
|
||||||
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
|
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
|
||||||
|
|
||||||
@ -185,7 +185,7 @@ def test_decode():
|
|||||||
# the preallocated block.
|
# the preallocated block.
|
||||||
for _ in range(6 + 11):
|
for _ in range(6 + 11):
|
||||||
req0.append_output_token_ids(12)
|
req0.append_output_token_ids(12)
|
||||||
new_blocks = manager.append_slots(req0, 17)
|
new_blocks = manager.allocate_slots(req0, 17)
|
||||||
# Plus one preallocated block.
|
# Plus one preallocated block.
|
||||||
assert new_blocks is not None and len(new_blocks) == 2
|
assert new_blocks is not None and len(new_blocks) == 2
|
||||||
|
|
||||||
@ -395,12 +395,14 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
|
|||||||
req.num_computed_tokens = block_size
|
req.num_computed_tokens = block_size
|
||||||
assert len(blocks) == 1 + num_preallocated_blocks
|
assert len(blocks) == 1 + num_preallocated_blocks
|
||||||
|
|
||||||
# Assume all computed.
|
# Assume all computed. Only when num_preallocate_tokens > 0 do we need to
|
||||||
manager.append_slots(req, block_size * (len(blocks) - 1))
|
# consume the previously preallocated blocks.
|
||||||
req.num_computed_tokens = block_size * len(blocks)
|
if num_preallocated_blocks > 0:
|
||||||
|
manager.allocate_slots(req, block_size * (len(blocks) - 1))
|
||||||
|
req.num_computed_tokens = block_size * len(blocks)
|
||||||
|
|
||||||
# Append 1 block.
|
# Append 1 block.
|
||||||
blocks = manager.append_slots(req, block_size)
|
blocks = manager.allocate_slots(req, block_size)
|
||||||
assert len(blocks) == 1 + num_preallocated_blocks
|
assert len(blocks) == 1 + num_preallocated_blocks
|
||||||
|
|
||||||
|
|
||||||
@ -503,7 +505,7 @@ def test_mm_prefix_caching():
|
|||||||
# Append slots without allocating a new block.
|
# Append slots without allocating a new block.
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
req0.append_output_token_ids(8)
|
req0.append_output_token_ids(8)
|
||||||
new_blocks = manager.append_slots(req0, 5)
|
new_blocks = manager.allocate_slots(req0, 5)
|
||||||
assert new_blocks is not None and len(new_blocks) == 0
|
assert new_blocks is not None and len(new_blocks) == 0
|
||||||
|
|
||||||
# The just completed block should have hashes with extra keys.
|
# The just completed block should have hashes with extra keys.
|
||||||
@ -603,7 +605,7 @@ def test_reset_prefix_cache():
|
|||||||
unique_token_ids = [3] * 7
|
unique_token_ids = [3] * 7
|
||||||
all_token_ids = full_block_token_ids + unique_token_ids
|
all_token_ids = full_block_token_ids + unique_token_ids
|
||||||
req0 = make_request("0", all_token_ids)
|
req0 = make_request("0", all_token_ids)
|
||||||
blocks = manager.allocate_slots(req0, 55, [])
|
blocks = manager.allocate_slots(req0, 55)
|
||||||
assert [b.block_id for b in blocks] == [0, 1, 2, 3]
|
assert [b.block_id for b in blocks] == [0, 1, 2, 3]
|
||||||
|
|
||||||
unique_token_ids = [4] * 7
|
unique_token_ids = [4] * 7
|
||||||
@ -639,7 +641,7 @@ def test_uncache_blocks():
|
|||||||
)
|
)
|
||||||
|
|
||||||
req0 = make_request("0", list(range(30)))
|
req0 = make_request("0", list(range(30)))
|
||||||
blocks = manager.allocate_slots(req0, 30, [])
|
blocks = manager.allocate_slots(req0, 30)
|
||||||
assert [b.block_id for b in blocks] == [0, 1]
|
assert [b.block_id for b in blocks] == [0, 1]
|
||||||
assert len(manager.cached_block_hash_to_block) == 1
|
assert len(manager.cached_block_hash_to_block) == 1
|
||||||
|
|
||||||
@ -648,7 +650,7 @@ def test_uncache_blocks():
|
|||||||
# Simulate speculative tokens.
|
# Simulate speculative tokens.
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
req0.append_output_token_ids(8)
|
req0.append_output_token_ids(8)
|
||||||
manager.append_slots(req0, 5)
|
manager.allocate_slots(req0, 5)
|
||||||
assert len(manager.cached_block_hash_to_block) == 2
|
assert len(manager.cached_block_hash_to_block) == 2
|
||||||
|
|
||||||
# After sampling, assuming only 1 token is accepted.
|
# After sampling, assuming only 1 token is accepted.
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, Iterable, List, Optional, Tuple
|
from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
@ -67,7 +67,8 @@ class KVCacheManager:
|
|||||||
# Mapping from request ID to blocks to track the blocks allocated
|
# Mapping from request ID to blocks to track the blocks allocated
|
||||||
# for each request, so that we can free the blocks when the request
|
# for each request, so that we can free the blocks when the request
|
||||||
# is finished.
|
# is finished.
|
||||||
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}
|
self.req_to_blocks: DefaultDict[str,
|
||||||
|
List[KVCacheBlock]] = defaultdict(list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def usage(self) -> float:
|
def usage(self) -> float:
|
||||||
@ -115,33 +116,75 @@ class KVCacheManager:
|
|||||||
num_computed_tokens = len(computed_blocks) * self.block_size
|
num_computed_tokens = len(computed_blocks) * self.block_size
|
||||||
return computed_blocks, num_computed_tokens
|
return computed_blocks, num_computed_tokens
|
||||||
|
|
||||||
def append_slots(
|
def allocate_slots(
|
||||||
self,
|
self,
|
||||||
request: Request,
|
request: Request,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
|
new_computed_blocks: Optional[List[KVCacheBlock]] = None
|
||||||
) -> Optional[List[KVCacheBlock]]:
|
) -> Optional[List[KVCacheBlock]]:
|
||||||
"""Append slots to the block table of the request.
|
"""Add slots for a request with new tokens to append.
|
||||||
We first append slots to already allocated blocks. If the allocated
|
|
||||||
blocks are not enough, we allocate new blocks.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The request to append slots.
|
request: The request to allocate slots.
|
||||||
num_tokens: The number of tokens to append.
|
num_tokens: The number of tokens to allocate. Note that this does
|
||||||
|
not include the tokens that have already been computed.
|
||||||
|
new_computed_blocks: A list of new computed blocks just hitting the
|
||||||
|
prefix caching.
|
||||||
|
|
||||||
|
Blocks layout:
|
||||||
|
-----------------------------------------------------------------------
|
||||||
|
| < computed > | < new computed > | < new > | < pre-allocated > |
|
||||||
|
-----------------------------------------------------------------------
|
||||||
|
| < required > |
|
||||||
|
--------------------------------------------------
|
||||||
|
| < full > |
|
||||||
|
------------------------------------------------
|
||||||
|
| <new full> |
|
||||||
|
--------------
|
||||||
|
The following *_blocks are illustrated in this layout.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of new blocks if new blocks are allocated, or None
|
A list of newly allocated blocks, or None if they cannot be allocated.
|
||||||
if new blocks are required but cannot be allocated.
|
|
||||||
"""
|
"""
|
||||||
num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
|
if num_tokens == 0:
|
||||||
|
raise ValueError("num_tokens must be greater than 0")
|
||||||
|
|
||||||
|
new_computed_blocks = new_computed_blocks or []
|
||||||
|
|
||||||
|
# The total number of computed tokens is the request's computed tokens plus
|
||||||
|
# the new prefix caching hits
|
||||||
|
num_computed_tokens = (request.num_computed_tokens +
|
||||||
|
len(new_computed_blocks) * self.block_size)
|
||||||
|
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
|
||||||
self.block_size)
|
self.block_size)
|
||||||
req_blocks = self.req_to_blocks[request.request_id]
|
req_blocks = self.req_to_blocks[request.request_id]
|
||||||
|
num_new_blocks = (num_required_blocks - len(req_blocks) -
|
||||||
|
len(new_computed_blocks))
|
||||||
|
|
||||||
num_new_blocks = num_required_blocks - len(req_blocks)
|
# If a computed block of a request is an eviction candidate (in the
|
||||||
if num_new_blocks > self.free_block_queue.num_free_blocks:
|
# free queue and ref_cnt == 0), it cannot be counted as a free block
|
||||||
# Need to allocate new blocks due to insufficient pre-allocated
|
# when allocating this request.
|
||||||
# slots, but we cannot allocate new blocks due to the limit.
|
num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
|
||||||
|
if blk.ref_cnt == 0)
|
||||||
|
if (num_new_blocks > self.free_block_queue.num_free_blocks -
|
||||||
|
num_evictable_computed_blocks):
|
||||||
|
# Cannot allocate new blocks
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Touch the computed blocks to make sure they won't be evicted.
|
||||||
|
if self.enable_caching:
|
||||||
|
self._touch(new_computed_blocks)
|
||||||
|
else:
|
||||||
|
assert not new_computed_blocks, (
|
||||||
|
"Computed blocks should be empty when "
|
||||||
|
"prefix caching is disabled")
|
||||||
|
|
||||||
|
# Append the new computed blocks to the request blocks until now to
|
||||||
|
# avoid the case where the new blocks cannot be allocated.
|
||||||
|
req_blocks.extend(new_computed_blocks)
|
||||||
|
|
||||||
|
# Start to handle new blocks
|
||||||
|
|
||||||
if num_new_blocks <= 0:
|
if num_new_blocks <= 0:
|
||||||
# No new block is needed.
|
# No new block is needed.
|
||||||
new_blocks = []
|
new_blocks = []
|
||||||
@ -160,112 +203,29 @@ class KVCacheManager:
|
|||||||
)
|
)
|
||||||
assert num_new_blocks > 0
|
assert num_new_blocks > 0
|
||||||
|
|
||||||
|
# Concatenate the computed block IDs and the new block IDs.
|
||||||
new_blocks = self._get_new_blocks(num_new_blocks)
|
new_blocks = self._get_new_blocks(num_new_blocks)
|
||||||
req_blocks.extend(new_blocks)
|
req_blocks.extend(new_blocks)
|
||||||
|
|
||||||
if not self.enable_caching:
|
if not self.enable_caching:
|
||||||
return new_blocks
|
return new_blocks
|
||||||
|
|
||||||
num_computed_full_blocks = (request.num_computed_tokens //
|
|
||||||
self.block_size)
|
|
||||||
|
|
||||||
# NOTE(rickyx): We are assuming the `num_tokens` are actual
|
# NOTE(rickyx): We are assuming the `num_tokens` are actual
|
||||||
# tokens rather than lookahead slots (e.g. for speculative decoding).
|
# tokens rather than lookahead slots (e.g. for speculative decoding).
|
||||||
# TODO(rickyx): When supporting speculative decoding, we will need to
|
# TODO(rickyx): When supporting speculative decoding, we will need to
|
||||||
# differentiate between them so that we can know how many blocks are
|
# differentiate between them so that we can know how many blocks are
|
||||||
# full after appending the actual tokens.
|
# full after appending the actual tokens.
|
||||||
num_full_blocks_after_append = (request.num_computed_tokens +
|
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
|
||||||
num_tokens) // self.block_size
|
num_computed_full_blocks = num_computed_tokens // self.block_size
|
||||||
assert num_full_blocks_after_append <= len(req_blocks)
|
new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
|
||||||
|
|
||||||
new_full_blocks = req_blocks[
|
|
||||||
num_computed_full_blocks:num_full_blocks_after_append]
|
|
||||||
if new_full_blocks:
|
if new_full_blocks:
|
||||||
self._cache_full_blocks(
|
self._cache_full_blocks(
|
||||||
request=request,
|
request=request,
|
||||||
blk_start_idx=num_computed_full_blocks,
|
blk_start_idx=num_computed_full_blocks,
|
||||||
full_blocks=new_full_blocks,
|
|
||||||
prev_block=req_blocks[num_computed_full_blocks - 1]
|
|
||||||
if num_computed_full_blocks >= 1 else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_blocks
|
|
||||||
|
|
||||||
def allocate_slots(
|
|
||||||
self,
|
|
||||||
request: Request,
|
|
||||||
num_tokens: int,
|
|
||||||
computed_blocks: List[KVCacheBlock],
|
|
||||||
) -> Optional[List[KVCacheBlock]]:
|
|
||||||
"""Allocate slots for a new request.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: The request to allocate slots.
|
|
||||||
num_tokens: The number of tokens to allocate. Note that this does
|
|
||||||
not include the tokens that have already been computed.
|
|
||||||
computed_blocks: A list of computed blocks.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of new allocated blocks.
|
|
||||||
"""
|
|
||||||
if num_tokens == 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"num_tokens must be greater than 0, got {num_tokens}")
|
|
||||||
|
|
||||||
# If a computed block of a request is an eviction candidate (in the
|
|
||||||
# free queue and ref_cnt == 0), it cannot be counted as a free block
|
|
||||||
# when allocating this request.
|
|
||||||
num_evictable_computed_blocks = sum(1 for blk in computed_blocks
|
|
||||||
if blk.ref_cnt == 0)
|
|
||||||
|
|
||||||
num_required_blocks = cdiv(num_tokens, self.block_size)
|
|
||||||
if (num_required_blocks > self.free_block_queue.num_free_blocks -
|
|
||||||
num_evictable_computed_blocks):
|
|
||||||
# Cannot allocate new blocks.
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Touch the computed blocks to make sure they won't be evicted.
|
|
||||||
if self.enable_caching:
|
|
||||||
self._touch(computed_blocks)
|
|
||||||
else:
|
|
||||||
assert not computed_blocks, (
|
|
||||||
"Computed blocks should be empty when "
|
|
||||||
"prefix caching is disabled")
|
|
||||||
|
|
||||||
# Determine the number of new blocks to allocate considering
|
|
||||||
# preallocated blocks.
|
|
||||||
num_new_blocks = min(
|
|
||||||
num_required_blocks + self.num_preallocate_blocks,
|
|
||||||
self.free_block_queue.num_free_blocks,
|
|
||||||
# Should not exceed the maximum number of blocks per request.
|
|
||||||
# This is especially because the block table has the shape
|
|
||||||
# [..., max_num_blocks_per_req].
|
|
||||||
# TODO(woosuk): Check and reject requests if
|
|
||||||
# num_prompt_tokens + max_tokens > max_model_len.
|
|
||||||
self.max_num_blocks_per_req - len(computed_blocks),
|
|
||||||
)
|
|
||||||
assert num_new_blocks > 0
|
|
||||||
|
|
||||||
# Concatenate the computed block IDs and the new block IDs.
|
|
||||||
new_blocks = self._get_new_blocks(num_new_blocks)
|
|
||||||
self.req_to_blocks[request.request_id] = computed_blocks + new_blocks
|
|
||||||
|
|
||||||
if not self.enable_caching:
|
|
||||||
return new_blocks
|
|
||||||
|
|
||||||
num_computed_tokens = len(computed_blocks) * self.block_size
|
|
||||||
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
|
|
||||||
|
|
||||||
new_full_blocks = self.req_to_blocks[
|
|
||||||
request.request_id][len(computed_blocks):num_full_blocks]
|
|
||||||
if new_full_blocks:
|
|
||||||
self._cache_full_blocks(
|
|
||||||
request=request,
|
|
||||||
blk_start_idx=len(computed_blocks),
|
|
||||||
# The new full blocks are the full blocks that are not computed.
|
# The new full blocks are the full blocks that are not computed.
|
||||||
full_blocks=new_full_blocks,
|
full_blocks=new_full_blocks,
|
||||||
prev_block=computed_blocks[-1] if computed_blocks else None,
|
prev_block=(req_blocks[num_computed_full_blocks - 1]
|
||||||
)
|
if num_computed_full_blocks > 0 else None))
|
||||||
|
|
||||||
return new_blocks
|
return new_blocks
|
||||||
|
|
||||||
|
@ -138,7 +138,7 @@ class Scheduler:
|
|||||||
assert num_new_tokens > 0
|
assert num_new_tokens > 0
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
new_blocks = self.kv_cache_manager.append_slots(
|
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||||
request, num_new_tokens)
|
request, num_new_tokens)
|
||||||
if new_blocks is None:
|
if new_blocks is None:
|
||||||
# The request cannot be scheduled.
|
# The request cannot be scheduled.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user