[v1] Move block pool operations to a separate class (#13973)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>

parent: b526ca6726
commit: 28943d36ce
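
In short: KVCacheManager previously owned the block list, the free-block queue, and the prefix-cache map directly; this commit moves that state into a new BlockPool class and makes the manager delegate to it. A condensed sketch of the resulting shape, with names taken from the diff below (the real KVCacheManager constructor takes more arguments than shown here):

    from vllm.v1.core.block_pool import BlockPool


    class KVCacheManager:
        # Condensed sketch only; the real class carries more state.

        def __init__(self, num_gpu_blocks: int, enable_caching: bool):
            # Replaces the inline List[KVCacheBlock], FreeKVCacheBlockQueue,
            # and cached_block_hash_to_block dict that used to live here.
            self.block_pool = BlockPool(num_gpu_blocks, enable_caching)

        def usage(self) -> float:
            # Block accounting is now a one-line delegation.
            return self.block_pool.get_usage()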
@@ -1,12 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 """Compare the with and without prefix caching."""
+from typing import List
+
 import pytest

 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import cdiv
+from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
-from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
+from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
+                                         hash_block_tokens)


 def make_request(request_id,
@@ -62,14 +66,14 @@ def test_prefill():
     for block_id in (0, 1, 2):
         block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
         block_hash = hash_block_tokens(parent_block_hash, block_tokens)
-        assert manager.block_pool[block_id].block_hash == block_hash
-        assert manager.block_pool[block_id].ref_cnt == 1
+        assert manager.block_pool.blocks[block_id].block_hash == block_hash
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1
         parent_block_hash = block_hash.hash_value

     # Check partial/preallocated block metadata
     for block_id in (3, 4):
-        assert manager.block_pool[block_id].block_hash is None
-        assert manager.block_pool[block_id].ref_cnt == 1
+        assert manager.block_pool.blocks[block_id].block_hash is None
+        assert manager.block_pool.blocks[block_id].ref_cnt == 1

     # Cache hit in the common prefix when the original block is still in use.
     # Incomplete 1 block (5 tokens)
@@ -86,20 +90,21 @@ def test_prefill():
         assert block.ref_cnt == 2

     # At this point, we should have 3 free blocks left.
-    assert manager.free_block_queue.num_free_blocks == 3
+    assert manager.block_pool.free_block_queue.num_free_blocks == 3

     manager.free(req0)
     manager.free(req1)

     # All blocks should be available.
-    assert manager.free_block_queue.num_free_blocks == 10
+    assert manager.block_pool.free_block_queue.num_free_blocks == 10
     # The order should be
     # [unallocated (7, 8, 9)]
     # [unique_req0 (4, 3)]
     # [unique_req1 (6, 5)]
     # [common (2, 1, 0)]
     assert [
-        b.block_id for b in manager.free_block_queue.get_all_free_blocks()
+        b.block_id
+        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
     ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]

     # Cache hit in the common prefix when the original block is already free.
@@ -116,12 +121,14 @@ def test_prefill():

     # Although we only have 5 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
-    assert manager.free_block_queue.num_free_blocks == 5
+    assert manager.block_pool.free_block_queue.num_free_blocks == 5
     assert all([
-        b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks()
+        b.ref_cnt == 0
+        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
     ])
-    assert len([b
-                for b in manager.free_block_queue.get_all_free_blocks()]) == 5
+    assert len([
+        b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
+    ]) == 5

     manager.free(req2)
@@ -133,9 +140,9 @@ def test_prefill():
     blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
     # This block ID order also checks the eviction order.
     assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
-    assert manager.free_block_queue.num_free_blocks == 0
-    assert manager.free_block_queue.free_list_head is None
-    assert manager.free_block_queue.free_list_tail is None
+    assert manager.block_pool.free_block_queue.num_free_blocks == 0
+    assert manager.block_pool.free_block_queue.free_list_head is None
+    assert manager.block_pool.free_block_queue.free_list_tail is None


 def test_decode():
@@ -219,13 +226,14 @@ def test_evict():
     assert len(blocks) == 3  # 3 full blocks
     last_token_id += 3 * 16

-    assert manager.free_block_queue.num_free_blocks == 0
+    assert manager.block_pool.free_block_queue.num_free_blocks == 0

     manager.free(req0)
     manager.free(req1)
-    assert manager.free_block_queue.num_free_blocks == 10
+    assert manager.block_pool.free_block_queue.num_free_blocks == 10
     assert [
-        b.block_id for b in manager.free_block_queue.get_all_free_blocks()
+        b.block_id
+        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
     ] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]

     # Touch the first 2 blocks.
@@ -235,7 +243,7 @@ def test_evict():
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3, computed_blocks)
     assert [b.block_id for b in blocks] == [6, 5]
-    assert manager.free_block_queue.num_free_blocks == 6
+    assert manager.block_pool.free_block_queue.num_free_blocks == 6


 def test_hash_block_correct_reuse():
@@ -274,7 +282,7 @@ def test_hash_block_correct_reuse():
     blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
     assert len(blocks) == 1

-    assert manager.block_pool[blocks[0].block_id].block_hash is None
+    assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None


 def test_computed_blocks_not_evicted():
@@ -413,13 +421,9 @@ def test_cache_blocks():
     function of KVCacheManager.
     """
     block_size = 4
-    manager = KVCacheManager(
-        block_size=block_size,
+    block_pool = BlockPool(
         num_gpu_blocks=5,
-        max_model_len=8192,
-        sliding_window=None,
         enable_caching=True,
-        num_preallocate_tokens=0,
     )
     # Req:
     # Block 0: [0, 1, 2, 3]
@@ -430,26 +434,31 @@ def test_cache_blocks():

     # Test that blocks are cached correctly for 2 full blocks from the start.
     blocks = [KVCacheBlock(block_id=i) for i in range(2)]
+    block_hashes: List[BlockHashType] = []

-    manager._cache_full_blocks(
+    block_pool.cache_full_blocks(
         request=req,
-        blk_start_idx=0,
-        full_blocks=blocks,
-        prev_block=None,
+        blocks=blocks,
+        block_hashes=block_hashes,
+        num_cached_blocks=0,
+        num_full_blocks=2,
+        block_size=block_size,
     )

-    assert len(manager.cached_block_hash_to_block) == 2
+    assert len(block_pool.cached_block_hash_to_block) == 2
     assert all([block.block_hash is not None for block in blocks])

     # Test that blocks that don't start from the beginning are cached correctly.
-    blocks = [KVCacheBlock(block_id=2)]
-    manager._cache_full_blocks(
+    blocks += [KVCacheBlock(block_id=2)]
+    block_pool.cache_full_blocks(
         request=req,
-        blk_start_idx=2,
-        full_blocks=blocks,
-        prev_block=None,
+        blocks=blocks,
+        block_hashes=block_hashes,
+        num_cached_blocks=2,
+        num_full_blocks=3,
+        block_size=block_size,
     )
-    assert len(manager.cached_block_hash_to_block) == 3
+    assert len(block_pool.cached_block_hash_to_block) == 3
     assert blocks[0].block_hash is not None
@@ -580,7 +589,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
     # but it cannot be allocated due to insufficient free blocks (2).
     # In this case, the ref_cnt of the computed blocks should not be changed.
-    assert manager.free_block_queue.num_free_blocks == 5
+    assert manager.block_pool.free_block_queue.num_free_blocks == 5
     req3 = make_request("3", common_token_ids * 3)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
     assert computed_blocks == block_part1
@@ -621,12 +630,12 @@ def test_reset_prefix_cache():

     # Failed to reset prefix cache because some blocks are not freed yet.
     assert not manager.reset_prefix_cache()
-    assert manager.cached_block_hash_to_block
+    assert manager.block_pool.cached_block_hash_to_block

     # Free the blocks.
     manager.free(req0)
     manager.free(req1)

     assert manager.reset_prefix_cache()
-    assert not manager.cached_block_hash_to_block
-    assert all([blk.block_hash is None for blk in manager.block_pool])
+    assert not manager.block_pool.cached_block_hash_to_block
+    assert all([blk.block_hash is None for blk in manager.block_pool.blocks])
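
The hash-chain arithmetic used by the assertions above is unchanged by the refactor: each block's hash folds in its parent's hash_value before hashing the block's own tokens. A minimal sketch, assuming a vLLM checkout where the test imports above resolve; the token values are illustrative:

    from vllm.v1.core.kv_cache_utils import hash_block_tokens

    block_size = 16
    all_token_ids = list(range(3 * block_size))  # illustrative token IDs

    parent_block_hash = None
    for block_id in range(3):
        block_tokens = tuple(
            all_token_ids[block_id * block_size:(block_id + 1) * block_size])
        block_hash = hash_block_tokens(parent_block_hash, block_tokens)
        # Chain into the next block, as the test_prefill assertions do.
        parent_block_hash = block_hash.hash_value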
vllm/v1/core/block_pool.py (new file, 285 lines)
@@ -0,0 +1,285 @@
+# SPDX-License-Identifier: Apache-2.0
+from collections import defaultdict
+from typing import Dict, Iterable, List, Optional
+
+from vllm.logger import init_logger
+from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
+                                         KVCacheBlock,
+                                         generate_block_hash_extra_keys,
+                                         hash_block_tokens)
+from vllm.v1.request import Request
+
+logger = init_logger(__name__)
+
+
+class BlockPool:
+    """BlockPool that manages KVCacheBlocks.
+    It provides methods to allocate, free and cache the kv cache blocks. The
+    free_block_queue stores the free blocks in eviction order to enable
+    allocation, free, and cache eviction. The cached_block_hash_to_block
+    maps between block hash and cached block to support finding cached blocks
+    by their block hash.
+
+    Args:
+        num_gpu_blocks: The number of blocks in the pool.
+        enable_caching: Whether to enable prefix caching.
+    """
+
+    def __init__(self, num_gpu_blocks: int, enable_caching: bool):
+        self.num_gpu_blocks = num_gpu_blocks
+        self.enable_caching = enable_caching
+        # All kv-cache blocks.
+        self.blocks: List[KVCacheBlock] = [
+            KVCacheBlock(idx) for idx in range(num_gpu_blocks)
+        ]
+        # Free block queue that constructs and manipulates a doubly linked
+        # list of free blocks (including eviction candidates when caching is
+        # enabled).
+        self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
+
+        # {block_hash: {block ID: block}}. A cached block is
+        # a full block with a block hash that can be used for prefix caching.
+        # The cached block may be used by running requests or in the
+        # free_block_queue that could potentially be evicted.
+        # NOTE: We currently don't de-duplicate the blocks in the cache,
+        # meaning that if a block becomes full and is cached, we don't check
+        # if there is already an identical block in the cache. This is because
+        # we want to make sure the allocated block IDs won't change so that
+        # block tables are append-only.
+        self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
+            int, KVCacheBlock]] = defaultdict(dict)
+
+    def get_cached_block(self,
+                         block_hash: BlockHashType) -> Optional[KVCacheBlock]:
+        """Get a cached block by the block hash, or None if cache miss.
+        If there are duplicated blocks, we return the first block in the cache.
+
+        Args:
+            block_hash: The hash value of the block.
+
+        Returns:
+            The cached block if it exists, or None.
+        """
+        if block_hash in self.cached_block_hash_to_block:
+            first_block_id = list(
+                self.cached_block_hash_to_block[block_hash].keys())[0]
+            return self.cached_block_hash_to_block[block_hash][first_block_id]
+        return None
+
+    def cache_full_blocks(
+        self,
+        request: Request,
+        blocks: List[KVCacheBlock],
+        block_hashes: List[BlockHashType],
+        num_cached_blocks: int,
+        num_full_blocks: int,
+        block_size: int,
+    ) -> None:
+        """Cache a list of full blocks for prefix caching.
+        This function takes a list of blocks whose block hash metadata
+        will be updated and cached. Given a request, it computes the
+        block hashes for the blocks starting from `num_cached_blocks` to
+        `num_full_blocks`, updating the metadata for each block
+        and caching them in the `cached_block_hash_to_block`.
+
+        Args:
+            request: The request to cache the blocks.
+            blocks: All blocks in the request.
+            block_hashes: Block hashes of the blocks in the request. Note that
+                this list may be shorter than the blocks list. In this case the
+                missing block hashes will be computed in this function.
+            num_cached_blocks: The number of blocks that are already cached.
+            num_full_blocks: The number of blocks that are full and should
+                be cached after this function.
+            block_size: Number of tokens in each block.
+        """
+        if num_cached_blocks == num_full_blocks:
+            return
+        new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
+        assert len(block_hashes) >= num_cached_blocks
+        new_block_hashes = block_hashes[num_cached_blocks:]
+
+        # Update the new blocks with the block hashes through the chain.
+        if num_cached_blocks == 0:
+            prev_block_hash_value = None
+        else:
+            prev_block = blocks[num_cached_blocks - 1]
+            assert prev_block.block_hash is not None
+            prev_block_hash_value = prev_block.block_hash.hash_value
+
+        # Find the first uncached block.
+        # FIXME: num_cached_blocks should be corrected by the caller
+        # so this should never happen.
+        offset = 0
+        for blk in new_full_blocks:
+            if blk.block_hash is None:
+                break
+            else:
+                prev_block_hash_value = blk.block_hash.hash_value
+                offset += 1
+        else:
+            # All blocks are cached.
+            return
+
+        for i, blk in enumerate(new_full_blocks[offset:]):
+            blk_idx = num_cached_blocks + offset + i
+            assert blk.block_hash is None
+
+            if i + offset < len(new_block_hashes):
+                # The block hash may already be computed in
+                # "get_computed_blocks" if the tokens are not generated by
+                # this request (either the prompt tokens or the previously
+                # generated tokens with preemption). In this case we simply
+                # reuse the block hash.
+                block_hash = new_block_hashes[i + offset]
+            else:
+                # Otherwise compute the block hash and cache it in the request
+                # in case it will be preempted in the future.
+                start_token_idx = blk_idx * block_size
+                end_token_idx = (blk_idx + 1) * block_size
+                block_tokens = request.all_token_ids[
+                    start_token_idx:end_token_idx]
+                assert len(block_tokens) == block_size, (
+                    f"Expected {block_size} tokens, got "
+                    f"{len(block_tokens)} at {blk_idx}th block for request "
+                    f"{request.request_id}({request})")
+
+                # Generate extra keys for multi-modal inputs. Note that since
+                # we reach this branch only when the block is completed with
+                # generated tokens, we only need to consider the last mm input.
+                extra_keys, _ = generate_block_hash_extra_keys(
+                    request, start_token_idx, end_token_idx, -1)
+
+                # Compute the hash of the current block.
+                block_hash = hash_block_tokens(prev_block_hash_value,
+                                               block_tokens, extra_keys)
+                block_hashes.append(block_hash)
+
+            # Update and add the full block to the cache.
+            blk.block_hash = block_hash
+            self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
+            prev_block_hash_value = block_hash.hash_value
+
+    def get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
+        """Get new blocks from the free block pool.
+
+        Note that we do not check block cache in this function.
+
+        Args:
+            num_blocks: The number of blocks to allocate.
+
+        Returns:
+            A list of new blocks.
+        """
+        if num_blocks > self.get_num_free_blocks():
+            raise ValueError(
+                f"Cannot get {num_blocks} free blocks from the pool")
+
+        ret: List[KVCacheBlock] = []
+        idx = 0
+        while idx < num_blocks:
+            # First allocate blocks.
+            curr_block = self.free_block_queue.popleft()
+            assert curr_block.ref_cnt == 0
+
+            # If the block is cached, evict it.
+            if self.enable_caching:
+                self._maybe_evict_cached_block(curr_block)
+
+            curr_block.incr_ref()
+            ret.append(curr_block)
+            idx += 1
+
+        return ret
+
+    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
+        """
+        If a block is cached in `cached_block_hash_to_block`, we reset its hash
+        metadata and evict it from the cache.
+
+        Args:
+            block: The block to evict.
+
+        Returns:
+            True if the block is evicted, False otherwise.
+        """
+        block_hash = block.block_hash
+        if block_hash and block_hash in self.cached_block_hash_to_block:
+            block.reset_hash()
+            del self.cached_block_hash_to_block[block_hash][block.block_id]
+
+            if len(self.cached_block_hash_to_block[block_hash]) == 0:
+                del self.cached_block_hash_to_block[block_hash]
+
+            return True
+        return False
+
+    def touch(self, blocks: List[KVCacheBlock]) -> None:
+        """Touching a block increases its reference count by 1, and may remove
+        the block from the free queue. This is used when a block is hit by
+        another request with the same prefix.
+
+        Args:
+            blocks: A list of blocks to touch.
+        """
+        for block in blocks:
+            # ref_cnt=0 means this block is in the free list (i.e. eviction
+            # candidate), so remove it.
+            if block.ref_cnt == 0:
+                self.free_block_queue.remove(block)
+            block.incr_ref()
+
+    def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
+        """Free a list of blocks. The blocks should be ordered by their
+        eviction priority, where the first block will be evicted first.
+
+        Args:
+            ordered_blocks: A list of blocks to free ordered by their eviction
+                priority.
+        """
+        for block in ordered_blocks:
+            block.decr_ref()
+            if block.ref_cnt == 0:
+                self.free_block_queue.append(block)
+
+    def reset_prefix_cache(self) -> bool:
+        """Reset prefix cache. This function may be used in RLHF
+        flows to invalidate prefix caching after the weights are updated,
+        or used for resetting prefix caching status for benchmarking.
+
+        Returns:
+            bool: True if the prefix cache is successfully reset,
+            False otherwise.
+        """
+        num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
+        if num_used_blocks > 0:
+            logger.warning(
+                "Failed to reset prefix cache because some "
+                "blocks (%d) are not freed yet", num_used_blocks)
+            return False
+
+        # Remove all hashes so that no new blocks will hit.
+        self.cached_block_hash_to_block = defaultdict(dict)
+
+        # Remove all hashes from all blocks.
+        for block in self.blocks:
+            block.reset_hash()
+
+        logger.info("Successfully reset prefix cache")
+        return True
+
+    def get_num_free_blocks(self) -> int:
+        """Get the number of free blocks in the pool.
+
+        Returns:
+            The number of free blocks.
+        """
+        return self.free_block_queue.num_free_blocks
+
+    def get_usage(self) -> float:
+        """Get the KV cache usage.
+
+        Returns:
+            The KV cache usage (between 0.0 and 1.0).
+        """
+        return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks)
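
A short lifecycle sketch of the class above, assuming a vLLM checkout where this module is importable; the block counts are illustrative, not taken from the tests:

    from vllm.v1.core.block_pool import BlockPool

    pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
    assert pool.get_num_free_blocks() == 10

    # Allocation pops blocks off the free queue with ref_cnt == 1.
    blocks = pool.get_new_blocks(4)
    assert pool.get_usage() == 0.4

    # A prefix-cache hit on the same blocks bumps each ref count to 2.
    pool.touch(blocks)

    # Blocks re-enter the free queue (in eviction order) only once their
    # ref count drops back to 0, so two frees are needed here.
    pool.free_blocks(reversed(blocks))
    pool.free_blocks(reversed(blocks))
    assert pool.get_num_free_blocks() == 10

    # With no blocks in use, the prefix cache can be reset.
    assert pool.reset_prefix_cache()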
vllm/v1/core/kv_cache_manager.py
@@ -5,10 +5,8 @@ from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple

 from vllm.logger import init_logger
 from vllm.utils import cdiv
-from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
-                                         KVCacheBlock,
-                                         generate_block_hash_extra_keys,
-                                         hash_block_tokens,
+from vllm.v1.core.block_pool import BlockPool
+from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
                                          hash_request_tokens)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request, RequestStatus
@@ -49,26 +47,7 @@ class KVCacheManager:
         self.num_preallocate_tokens = num_preallocate_tokens
         self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)

-        # A Block pool of all kv-cache blocks.
-        self.block_pool: List[KVCacheBlock] = [
-            KVCacheBlock(idx) for idx in range(num_gpu_blocks)
-        ]
-        # Free block queue that constructs and manipulates a doubly linked
-        # list of free blocks (including eviction candidates when caching is
-        # enabled).
-        self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool)
-
-        # {block_hash: {block ID: block}}. A cached block is
-        # a full block with a block hash that can be used for prefix caching.
-        # The cached block may be used by running requests or in the
-        # free_block_queue that could potentially be evicted.
-        # NOTE: We currently don't de-duplicate the blocks in the cache,
-        # meaning that if a block becomes full and is cached, we don't check
-        # if there is already an identical block in the cache. This is because
-        # we want to make sure the allocated block IDs won't change so that
-        # block tables are append-only.
-        self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
-            int, KVCacheBlock]] = defaultdict(dict)
+        self.block_pool = BlockPool(num_gpu_blocks, enable_caching)

         # Mapping from request ID to blocks to track the blocks allocated
         # for each request, so that we can free the blocks when the request
@@ -96,8 +75,7 @@ class KVCacheManager:
         Returns:
             The KV cache usage (between 0.0 and 1.0).
         """
-        return 1.0 - (self.free_block_queue.num_free_blocks /
-                      self.num_gpu_blocks)
+        return self.block_pool.get_usage()

     def make_prefix_cache_stats(self) -> PrefixCacheStats:
         """Get (and reset) the prefix cache stats.
@@ -139,7 +117,7 @@ class KVCacheManager:
             # block_hashes is a chain of block hashes. If a block hash is not
             # in the cached_block_hash_to_id, the following block hashes are
             # not computed yet for sure.
-            if cached_block := self._get_cached_block(block_hash):
+            if cached_block := self.block_pool.get_cached_block(block_hash):
                 computed_blocks.append(cached_block)
             else:
                 break
@@ -204,14 +182,14 @@ class KVCacheManager:
         # when allocating this request.
         num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
                                             if blk.ref_cnt == 0)
-        if (num_new_blocks > self.free_block_queue.num_free_blocks -
+        if (num_new_blocks > self.block_pool.get_num_free_blocks() -
                 num_evictable_computed_blocks):
             # Cannot allocate new blocks
             return None

         # Touch the computed blocks to make sure they won't be evicted.
         if self.enable_caching:
-            self._touch(new_computed_blocks)
+            self.block_pool.touch(new_computed_blocks)
         else:
             assert not new_computed_blocks, (
                 "Computed blocks should be empty when "
@@ -231,7 +209,7 @@ class KVCacheManager:
         # preallocated blocks.
         num_new_blocks = min(
             num_new_blocks + self.num_preallocate_blocks,
-            self.free_block_queue.num_free_blocks,
+            self.block_pool.get_num_free_blocks(),
             # Should not exceed the maximum number of blocks per request.
             # This is especially because the block table has the shape
             # [..., max_num_blocks_per_req].
@@ -240,29 +218,30 @@ class KVCacheManager:
         assert num_new_blocks > 0

         # Concatenate the computed block IDs and the new block IDs.
-        new_blocks = self._get_new_blocks(num_new_blocks)
+        new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
         req_blocks.extend(new_blocks)

         if not self.enable_caching:
             return new_blocks

+        # FIXME: `num_cached_blocks` is not correct when the prefix cache
+        # of a new request is hit.
         num_cached_blocks = self.num_cached_block[request.request_id]
         # Speculated tokens might be rejected in the future, so we do
         # not cache any speculated tokens. We only cache blocks with
         # generated (accepted) tokens.
         num_full_blocks_after_append = (num_computed_tokens + num_tokens - len(
             request.spec_token_ids)) // self.block_size
-        new_full_blocks = req_blocks[
-            num_cached_blocks:num_full_blocks_after_append]

-        if new_full_blocks:
-            self._cache_full_blocks(
-                request=request,
-                blk_start_idx=num_cached_blocks,
-                # The new full blocks are the full blocks that are not computed.
-                full_blocks=new_full_blocks,
-                prev_block=(req_blocks[num_cached_blocks -
-                                       1] if num_cached_blocks > 0 else None))
+        self.block_pool.cache_full_blocks(
+            request=request,
+            blocks=req_blocks,
+            block_hashes=self.req_to_block_hashes[request.request_id],
+            num_cached_blocks=num_cached_blocks,
+            num_full_blocks=num_full_blocks_after_append,
+            block_size=self.block_size,
+        )

         self.num_cached_block[
             request.request_id] = num_full_blocks_after_append
         return new_blocks
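
The full-block count computed in the hunk above deliberately excludes speculated tokens, since they may still be rejected. A worked example of that arithmetic, with illustrative numbers not taken from the diff:

    block_size = 16
    num_computed_tokens = 48  # three blocks' worth already computed
    num_tokens = 21           # tokens appended in this step
    num_spec_tokens = 5       # speculated tokens: may be rejected, never cached

    num_full_blocks_after_append = (
        num_computed_tokens + num_tokens - num_spec_tokens) // block_size
    assert num_full_blocks_after_append == 4  # (48 + 21 - 5) // 16 == 4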
@@ -283,11 +262,7 @@ class KVCacheManager:
         # freed first.
         ordered_blocks = reversed(blocks)

-        for block in ordered_blocks:
-            block.decr_ref()
-            if block.ref_cnt == 0:
-                self.free_block_queue.append(block)
-
+        self.block_pool.free_blocks(ordered_blocks)
         self.num_cached_block.pop(request.request_id, None)

     def reset_prefix_cache(self) -> bool:
@@ -299,25 +274,10 @@ class KVCacheManager:
             bool: True if the prefix cache is successfully reset,
             False otherwise.
         """
-        num_used_blocks = (self.num_gpu_blocks -
-                           self.free_block_queue.num_free_blocks)
-        if num_used_blocks > 0:
-            logger.warning(
-                "Failed to reset prefix cache because some "
-                "blocks (%d) are not freed yet", num_used_blocks)
-            return False
-
-        # Remove all hashes so that no new blocks will hit.
-        self.cached_block_hash_to_block = defaultdict(dict)
-
-        # Remove all hashes from all blocks.
-        for block in self.block_pool:
-            block.reset_hash()
-
-        logger.info("Successfully reset prefix cache")
-        return True
+        if self.block_pool.reset_prefix_cache():
+            self.prefix_cache_stats.reset = True
+            return True
+        return False

     def get_num_common_prefix_blocks(
         self,
@@ -367,177 +327,6 @@ class KVCacheManager:
             break
         return num_common_blocks

-    def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
-        """Get new blocks from the free block pool.
-
-        Note that we do not check block cache in this function.
-
-        Args:
-            num_blocks: The number of blocks to allocate.
-
-        Returns:
-            A list of new block.
-        """
-        if num_blocks > self.free_block_queue.num_free_blocks:
-            raise ValueError(
-                f"Cannot get {num_blocks} free blocks from the pool")
-
-        ret: List[KVCacheBlock] = []
-        idx = 0
-        while idx < num_blocks:
-            # First allocate blocks.
-            curr_block = self.free_block_queue.popleft()
-            assert curr_block.ref_cnt == 0
-
-            # If the block is cached, evict it.
-            if self.enable_caching:
-                self._maybe_evict_cached_block(curr_block)
-
-            curr_block.incr_ref()
-            ret.append(curr_block)
-            idx += 1
-
-        return ret
-
-    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
-        """
-        If a block is cached in `cached_block_hash_to_block`, we reset its hash
-        metadata and evict it from the cache.
-
-        Args:
-            block: The block to evict.
-
-        Returns:
-            True if the block is evicted, False otherwise.
-        """
-        block_hash = block.block_hash
-        if block_hash and block_hash in self.cached_block_hash_to_block:
-            block.reset_hash()
-            del self.cached_block_hash_to_block[block_hash][block.block_id]
-
-            if len(self.cached_block_hash_to_block[block_hash]) == 0:
-                del self.cached_block_hash_to_block[block_hash]
-
-            return True
-        return False
-
-    def _get_cached_block(self,
-                          block_hash: BlockHashType) -> Optional[KVCacheBlock]:
-        """Get a cached block by the block hash, or None if cache miss.
-        If there are duplicated blocks, we return the first block in the cache.
-
-        Args:
-            block_hash: The hash value of the block.
-
-        Returns:
-            The cached block if it exists, or None.
-        """
-        if block_hash in self.cached_block_hash_to_block:
-            first_block_id = list(
-                self.cached_block_hash_to_block[block_hash].keys())[0]
-            return self.cached_block_hash_to_block[block_hash][first_block_id]
-        return None
-
-    def _touch(self, blocks: List[KVCacheBlock]) -> None:
-        """Touch a block increases its reference count by 1, and may remove
-        the block from the free queue. This is used when a block is hit by
-        another request with the same prefix.
-
-        Args:
-            blocks: A list of blocks to touch.
-        """
-        for block in blocks:
-            # ref_cnt=0 means this block is in the free list (i.e. eviction
-            # candidate), so remove it.
-            if block.ref_cnt == 0:
-                self.free_block_queue.remove(block)
-            block.incr_ref()
-
-    def _cache_full_blocks(
-        self,
-        request: Request,
-        blk_start_idx: int,
-        full_blocks: List[KVCacheBlock],
-        prev_block: Optional[KVCacheBlock],
-    ) -> None:
-        """Cache a list of full blocks for prefix caching.
-
-        This function takes a list of blocks that will have their block hash
-        metadata to be updated and cached. Given a request, it computes the
-        block hashes for the blocks starting from `blk_start_idx` to the end
-        of the request's full blocks, updating the metadata for each block
-        and caching them in the `cached_block_hash_to_block`.
-
-        Args:
-            request: The request to cache the blocks.
-            blk_start_idx: The index of the first block in the request's blocks
-                to cache.
-            full_blocks: The list of blocks to update hash metadata.
-            prev_block: The previous block in the chain.
-        """
-        block_hashes = self.req_to_block_hashes[request.request_id]
-        num_cached_block_hashes = len(block_hashes)
-
-        # Update the new blocks with the block hashes through the chain.
-        prev_block_hash_value = None
-        if prev_block is not None:
-            # Previous block must have a block hash because it must be
-            # a full, cached block.
-            assert prev_block.block_hash is not None
-            prev_block_hash_value = prev_block.block_hash.hash_value
-
-        # Find the first uncached block. This case should only happen when
-        # speculative decoding is used.
-        offset = 0
-        for blk in full_blocks:
-            if blk.block_hash is None:
-                break
-            else:
-                prev_block_hash_value = blk.block_hash.hash_value
-                offset += 1
-        else:
-            # All blocks are cached.
-            return
-
-        for i, blk in enumerate(full_blocks[offset:]):
-            blk_idx = blk_start_idx + offset + i
-            assert blk.block_hash is None
-
-            if blk_idx < num_cached_block_hashes:
-                # The block hash may already be computed in
-                # "get_computed_blocks" if the tokens are not generated by
-                # this request (either the prompt tokens or the previously
-                # generated tokens with preemption). In this case we simply
-                # reuse the block hash.
-                block_hash = block_hashes[blk_idx]
-            else:
-                # Otherwise compute the block hash and cache it in the request
-                # in case it will be preempted in the future.
-                start_token_idx = blk_idx * self.block_size
-                end_token_idx = (blk_idx + 1) * self.block_size
-                block_tokens = request.all_token_ids[
-                    start_token_idx:end_token_idx]
-                assert len(block_tokens) == self.block_size, (
-                    f"Expected {self.block_size} tokens, got "
-                    f"{len(block_tokens)} at {blk_idx}th block for request "
-                    f"{request.request_id}({request})")
-
-                # Generate extra keys for multi-modal inputs. Note that since
-                # we reach to this branch only when the block is completed with
-                # generated tokens, we only need to consider the last mm input.
-                extra_keys, _ = generate_block_hash_extra_keys(
-                    request, start_token_idx, end_token_idx, -1)
-
-                # Compute the hash of the current block.
-                block_hash = hash_block_tokens(prev_block_hash_value,
-                                               block_tokens, extra_keys)
-                block_hashes.append(block_hash)
-
-            # Update and added the full block to the cache.
-            blk.block_hash = block_hash
-            self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
-            prev_block_hash_value = block_hash.hash_value
-
     def free_block_hashes(self, request: Request) -> None:
         """Discard the block hashes for the request.