[v1] Move block pool operations to a separate class (#13973)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Chen Zhang 2025-03-01 04:53:31 +08:00 committed by GitHub
parent b526ca6726
commit 28943d36ce
3 changed files with 359 additions and 276 deletions
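In brief, the change moves the block list, the free-block queue, and the prefix-cache hash map out of KVCacheManager into a new BlockPool class (vllm/v1/core/block_pool.py). A minimal sketch of the new surface, assuming vLLM at this commit is installed; this is not an excerpt from the diff:

from vllm.v1.core.block_pool import BlockPool

pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
assert pool.get_num_free_blocks() == 10   # every block starts on the free queue
assert pool.get_usage() == 0.0            # usage = 1 - free/total
blocks = pool.get_new_blocks(4)           # pops from the queue head; ref_cnt becomes 1
assert pool.get_usage() == 0.4
pool.free_blocks(reversed(blocks))        # return blocks in eviction-priority order
assert pool.reset_prefix_cache()          # True: no block is in use anymore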

tests/v1/core/test_prefix_caching.py

@@ -1,12 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare the with and without prefix caching."""
from typing import List
import pytest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_block_tokens)
def make_request(request_id,
@@ -62,14 +66,14 @@ def test_prefill():
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool.blocks[block_id].block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (3, 4):
assert manager.block_pool[block_id].block_hash is None
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
@@ -86,20 +90,21 @@ def test_prefill():
assert block.ref_cnt == 2
# At this point, we should have 3 free blocks left.
assert manager.free_block_queue.num_free_blocks == 3
assert manager.block_pool.free_block_queue.num_free_blocks == 3
manager.free(req0)
manager.free(req1)
# All blocks should be available.
assert manager.free_block_queue.num_free_blocks == 10
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
assert [
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
# Cache hit in the common prefix when the original block is already free.
@@ -116,12 +121,14 @@ def test_prefill():
# Although we only have 5 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
assert manager.free_block_queue.num_free_blocks == 5
assert manager.block_pool.free_block_queue.num_free_blocks == 5
assert all([
b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks()
b.ref_cnt == 0
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
])
assert len([b
for b in manager.free_block_queue.get_all_free_blocks()]) == 5
assert len([
b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
]) == 5
manager.free(req2)
@@ -133,9 +140,9 @@ def test_prefill():
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert manager.free_block_queue.num_free_blocks == 0
assert manager.free_block_queue.free_list_head is None
assert manager.free_block_queue.free_list_tail is None
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
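The block-ID orderings asserted above follow from FreeKVCacheBlockQueue semantics: popleft takes the head (the oldest eviction candidate), and re-freed blocks are appended at the tail. A small sketch, again assuming vLLM's v1 modules are importable:

from vllm.v1.core.kv_cache_utils import FreeKVCacheBlockQueue, KVCacheBlock

blocks = [KVCacheBlock(idx) for idx in range(3)]
queue = FreeKVCacheBlockQueue(blocks)
head = queue.popleft()          # block 0: the oldest eviction candidate
queue.append(head)              # a re-freed block rejoins at the tail
assert [b.block_id for b in queue.get_all_free_blocks()] == [1, 2, 0]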
def test_decode():
@@ -219,13 +226,14 @@ def test_evict():
assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16
assert manager.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.num_free_blocks == 0
manager.free(req0)
manager.free(req1)
assert manager.free_block_queue.num_free_blocks == 10
assert manager.block_pool.free_block_queue.num_free_blocks == 10
assert [
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]
# Touch the first 2 blocks.
@@ -235,7 +243,7 @@ def test_evict():
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6
assert manager.block_pool.free_block_queue.num_free_blocks == 6
def test_hash_block_correct_reuse():
@@ -274,7 +282,7 @@ def test_hash_block_correct_reuse():
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
assert len(blocks) == 1
assert manager.block_pool[blocks[0].block_id].block_hash is None
assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None
def test_computed_blocks_not_evicted():
@@ -413,13 +421,9 @@ def test_cache_blocks():
function of KVCacheManager.
"""
block_size = 4
manager = KVCacheManager(
block_size=block_size,
block_pool = BlockPool(
num_gpu_blocks=5,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
# Req:
# Block 0: [0, 1, 2, 3]
@@ -430,26 +434,31 @@ def test_cache_blocks():
# Test that blocks are cached correctly for 2 full blocks from the start.
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
block_hashes: List[BlockHashType] = []
manager._cache_full_blocks(
block_pool.cache_full_blocks(
request=req,
blk_start_idx=0,
full_blocks=blocks,
prev_block=None,
blocks=blocks,
block_hashes=block_hashes,
num_cached_blocks=0,
num_full_blocks=2,
block_size=block_size,
)
assert len(manager.cached_block_hash_to_block) == 2
assert len(block_pool.cached_block_hash_to_block) == 2
assert all([block.block_hash is not None for block in blocks])
# Test that blocks that don't start from the beginning are cached correctly.
blocks = [KVCacheBlock(block_id=2)]
manager._cache_full_blocks(
blocks += [KVCacheBlock(block_id=2)]
block_pool.cache_full_blocks(
request=req,
blk_start_idx=2,
full_blocks=blocks,
prev_block=None,
blocks=blocks,
block_hashes=block_hashes,
num_cached_blocks=2,
num_full_blocks=3,
block_size=block_size,
)
assert len(manager.cached_block_hash_to_block) == 3
assert len(block_pool.cached_block_hash_to_block) == 3
assert blocks[0].block_hash is not None
@@ -580,7 +589,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.free_block_queue.num_free_blocks == 5
assert manager.block_pool.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks == block_part1
@@ -621,12 +630,12 @@ def test_reset_prefix_cache():
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()
assert manager.cached_block_hash_to_block
assert manager.block_pool.cached_block_hash_to_block
# Free the blocks.
manager.free(req0)
manager.free(req1)
assert manager.reset_prefix_cache()
assert not manager.cached_block_hash_to_block
assert all([blk.block_hash is None for blk in manager.block_pool])
assert not manager.block_pool.cached_block_hash_to_block
assert all([blk.block_hash is None for blk in manager.block_pool.blocks])
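The hash-chaining behavior these tests rely on can be seen in isolation. A hedged sketch, assuming the vLLM modules imported above are available:

from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import hash_block_tokens

# Each full block's hash folds in the parent hash (see test_prefill above),
# so an equal hash implies the whole token prefix matches.
h0 = hash_block_tokens(None, tuple(range(16)))
h1 = hash_block_tokens(h0.hash_value, tuple(range(16, 32)))
assert h0 != h1                           # different tokens/parents, different hashes

pool = BlockPool(num_gpu_blocks=5, enable_caching=True)
assert pool.get_cached_block(h0) is None  # cache miss until cache_full_blocks runs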

vllm/v1/core/block_pool.py (new file, 285 lines)

@@ -0,0 +1,285 @@
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from typing import Dict, Iterable, List, Optional
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
from vllm.v1.request import Request
logger = init_logger(__name__)
class BlockPool:
"""BlockPool that manages KVCacheBlocks.
It provides methods to allocate, free, and cache the kv-cache blocks. The
free_block_queue stores the free blocks in eviction order to enable
allocation, freeing, and cache eviction. The cached_block_hash_to_block
maps between block hash and cached block to support finding cached blocks
by their block hash.
Args:
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
"""
def __init__(self, num_gpu_blocks: int, enable_caching: bool):
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
# All kv-cache blocks.
self.blocks: List[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
# {block_hash: {block ID: block}}. A cached block is
# a full block with a block hash that can be used for prefix caching.
# The cached block may be used by running requests or in the
# free_block_queue that could potentially be evicted.
# NOTE: We currently don't de-duplicate the blocks in the cache,
# meaning that if a block becomes full and is cached, we don't check
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
int, KVCacheBlock]] = defaultdict(dict)
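# Illustration (not in the original file): if two requests cached identical
# content into blocks 3 and 7 under hash h, the map would hold
# cached_block_hash_to_block[h] == {3: <block 3>, 7: <block 7>}.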
def get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
If there are duplicated blocks, we return the first block in the cache.
Args:
block_hash: The hash value of the block.
Returns:
The cached block if it exists, or None.
"""
if block_hash in self.cached_block_hash_to_block:
first_block_id = list(
self.cached_block_hash_to_block[block_hash].keys())[0]
return self.cached_block_hash_to_block[block_hash][first_block_id]
return None
def cache_full_blocks(
self,
request: Request,
blocks: List[KVCacheBlock],
block_hashes: List[BlockHashType],
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
) -> None:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks whose block-hash metadata will be
updated and cached. Given a request, it computes the
block hashes for the blocks starting from `num_cached_blocks` to
`num_full_blocks`, updating the metadata for each block
and caching them in the `cached_block_hash_to_block`.
Args:
request: The request to cache the blocks.
blocks: All blocks in the request.
block_hashes: Block hashes of the blocks in the request. Note that
this list may be shorter than the blocks list; in that case the
missing block hashes are computed in this function.
num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
"""
if num_cached_blocks == num_full_blocks:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(block_hashes) >= num_cached_blocks
new_block_hashes = block_hashes[num_cached_blocks:]
# Update the new blocks with the block hashes through the chain.
if num_cached_blocks == 0:
prev_block_hash_value = None
else:
prev_block = blocks[num_cached_blocks - 1]
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.hash_value
# Find the first uncached block.
# FIXME: num_cached_blocks should be corrected by the caller
# so this should never happen.
offset = 0
for blk in new_full_blocks:
if blk.block_hash is None:
break
else:
prev_block_hash_value = blk.block_hash.hash_value
offset += 1
else:
# All blocks are cached.
return
for i, blk in enumerate(new_full_blocks[offset:]):
blk_idx = num_cached_blocks + offset + i
assert blk.block_hash is None
if i + offset < len(new_block_hashes):
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = new_block_hashes[i + offset]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
start_token_idx = blk_idx * block_size
end_token_idx = (blk_idx + 1) * block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == block_size, (
f"Expected {block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Generate extra keys for multi-modal inputs. Note that since
# we reach this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
block_hashes.append(block_hash)
# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash_value = block_hash.hash_value
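# Illustration (not in the original file): for consecutive full blocks A, B,
# B's hash is hash_block_tokens(hash(A).hash_value, B_tokens, extra_keys),
# so a single hash lookup certifies the entire prefix, not just one block.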
def get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
"""Get new blocks from the free block pool.
Note that we do not check the block cache in this function.
Args:
num_blocks: The number of blocks to allocate.
Returns:
A list of new blocks.
"""
if num_blocks > self.get_num_free_blocks():
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
ret: List[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
# First allocate blocks.
curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0
# If the block is cached, evict it.
if self.enable_caching:
self._maybe_evict_cached_block(curr_block)
curr_block.incr_ref()
ret.append(curr_block)
idx += 1
return ret
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
"""
If a block is cached in `cached_block_hash_to_block`, we reset its hash
metadata and evict it from the cache.
Args:
block: The block to evict.
Returns:
True if the block is evicted, False otherwise.
"""
block_hash = block.block_hash
if block_hash and block_hash in self.cached_block_hash_to_block:
block.reset_hash()
del self.cached_block_hash_to_block[block_hash][block.block_id]
if len(self.cached_block_hash_to_block[block_hash]) == 0:
del self.cached_block_hash_to_block[block_hash]
return True
return False
def touch(self, blocks: List[KVCacheBlock]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
Args:
blocks: A list of blocks to touch.
"""
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0:
self.free_block_queue.remove(block)
block.incr_ref()
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their
eviction priority, where the first block will be evicted first.
Args:
ordered_blocks: A list of blocks to free ordered by their eviction
priority.
"""
for block in ordered_blocks:
block.decr_ref()
if block.ref_cnt == 0:
self.free_block_queue.append(block)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,
or used for resetting prefix caching status for benchmarking.
Returns:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
if num_used_blocks > 0:
logger.warning(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet", num_used_blocks)
return False
# Remove all hashes so that no new blocks will hit.
self.cached_block_hash_to_block = defaultdict(dict)
# Remove all hashes from all blocks.
for block in self.blocks:
block.reset_hash()
logger.info("Successfully reset prefix cache")
return True
def get_num_free_blocks(self) -> int:
"""Get the number of free blocks in the pool.
Returns:
The number of free blocks.
"""
return self.free_block_queue.num_free_blocks
def get_usage(self) -> float:
"""Get the KV cache usage.
Returns:
The KV cache usage (between 0.0 and 1.0).
"""
return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks)
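A worked example of the usage formula above, with hypothetical numbers:

num_gpu_blocks, num_free = 10, 5
usage = 1.0 - (num_free / num_gpu_blocks)
assert usage == 0.5  # half of the KV-cache blocks are in use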

vllm/v1/core/kv_cache_manager.py

@@ -5,10 +5,8 @@ from typing import DefaultDict, Dict, Iterable, List, Optional, Tuple
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens,
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_request_tokens)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
@@ -49,26 +47,7 @@ class KVCacheManager:
self.num_preallocate_tokens = num_preallocate_tokens
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
# A Block pool of all kv-cache blocks.
self.block_pool: List[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool)
# {block_hash: {block ID: block}}. A cached block is
# a full block with a block hash that can be used for prefix caching.
# The cached block may be used by running requests or in the
# free_block_queue that could potentially be evicted.
# NOTE: We currently don't de-duplicate the blocks in the cache,
# meaning that if a block becomes full and is cached, we don't check
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
int, KVCacheBlock]] = defaultdict(dict)
self.block_pool = BlockPool(num_gpu_blocks, enable_caching)
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
@@ -96,8 +75,7 @@ class KVCacheManager:
Returns:
The KV cache usage (between 0.0 and 1.0).
"""
return 1.0 - (self.free_block_queue.num_free_blocks /
self.num_gpu_blocks)
return self.block_pool.get_usage()
def make_prefix_cache_stats(self) -> PrefixCacheStats:
"""Get (and reset) the prefix cache stats.
@@ -139,7 +117,7 @@ class KVCacheManager:
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_block, the following block hashes are
# not computed yet for sure.
if cached_block := self._get_cached_block(block_hash):
if cached_block := self.block_pool.get_cached_block(block_hash):
computed_blocks.append(cached_block)
else:
break
@@ -204,14 +182,14 @@ class KVCacheManager:
# when allocating this request.
num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
if blk.ref_cnt == 0)
if (num_new_blocks > self.free_block_queue.num_free_blocks -
if (num_new_blocks > self.block_pool.get_num_free_blocks() -
num_evictable_computed_blocks):
# Cannot allocate new blocks
return None
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self._touch(new_computed_blocks)
self.block_pool.touch(new_computed_blocks)
else:
assert not new_computed_blocks, (
"Computed blocks should be empty when "
@@ -231,7 +209,7 @@ class KVCacheManager:
# preallocated blocks.
num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks,
self.block_pool.get_num_free_blocks(),
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
# [..., max_num_blocks_per_req].
@@ -240,29 +218,30 @@ class KVCacheManager:
assert num_new_blocks > 0
# Concatenate the computed block IDs and the new block IDs.
new_blocks = self._get_new_blocks(num_new_blocks)
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
if not self.enable_caching:
return new_blocks
# FIXME: `num_cached_blocks` is not correct when the prefix cache
# of a new request is hit.
num_cached_blocks = self.num_cached_block[request.request_id]
# Speculated tokens might be rejected in the future, so we do
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.
num_full_blocks_after_append = (num_computed_tokens + num_tokens - len(
request.spec_token_ids)) // self.block_size
new_full_blocks = req_blocks[
num_cached_blocks:num_full_blocks_after_append]
if new_full_blocks:
self._cache_full_blocks(
self.block_pool.cache_full_blocks(
request=request,
blk_start_idx=num_cached_blocks,
# The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks,
prev_block=(req_blocks[num_cached_blocks -
1] if num_cached_blocks > 0 else None))
blocks=req_blocks,
block_hashes=self.req_to_block_hashes[request.request_id],
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks_after_append,
block_size=self.block_size,
)
self.num_cached_block[
request.request_id] = num_full_blocks_after_append
return new_blocks
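The num_full_blocks_after_append computation above deliberately excludes speculative tokens, since they may later be rejected. A worked example with hypothetical numbers:

block_size = 16
num_computed_tokens, num_tokens = 30, 10   # 3 of the 10 new tokens are speculative
num_spec_tokens = 3
num_full_blocks_after_append = (
    num_computed_tokens + num_tokens - num_spec_tokens) // block_size
assert num_full_blocks_after_append == 2   # 37 accepted tokens -> 2 full blocks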
@@ -283,11 +262,7 @@ class KVCacheManager:
# freed first.
ordered_blocks = reversed(blocks)
for block in ordered_blocks:
block.decr_ref()
if block.ref_cnt == 0:
self.free_block_queue.append(block)
self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request.request_id, None)
def reset_prefix_cache(self) -> bool:
@@ -299,25 +274,10 @@ class KVCacheManager:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks = (self.num_gpu_blocks -
self.free_block_queue.num_free_blocks)
if num_used_blocks > 0:
logger.warning(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet", num_used_blocks)
return False
# Remove all hashes so that no new blocks will hit.
self.cached_block_hash_to_block = defaultdict(dict)
# Remove all hashes from all blocks.
for block in self.block_pool:
block.reset_hash()
if self.block_pool.reset_prefix_cache():
self.prefix_cache_stats.reset = True
logger.info("Successfully reset prefix cache")
return True
return False
def get_num_common_prefix_blocks(
self,
@@ -367,177 +327,6 @@ class KVCacheManager:
break
return num_common_blocks
def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
"""Get new blocks from the free block pool.
Note that we do not check the block cache in this function.
Args:
num_blocks: The number of blocks to allocate.
Returns:
A list of new blocks.
"""
if num_blocks > self.free_block_queue.num_free_blocks:
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
ret: List[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
# First allocate blocks.
curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0
# If the block is cached, evict it.
if self.enable_caching:
self._maybe_evict_cached_block(curr_block)
curr_block.incr_ref()
ret.append(curr_block)
idx += 1
return ret
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
"""
If a block is cached in `cached_block_hash_to_block`, we reset its hash
metadata and evict it from the cache.
Args:
block: The block to evict.
Returns:
True if the block is evicted, False otherwise.
"""
block_hash = block.block_hash
if block_hash and block_hash in self.cached_block_hash_to_block:
block.reset_hash()
del self.cached_block_hash_to_block[block_hash][block.block_id]
if len(self.cached_block_hash_to_block[block_hash]) == 0:
del self.cached_block_hash_to_block[block_hash]
return True
return False
def _get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
If there are duplicated blocks, we return the first block in the cache.
Args:
block_hash: The hash value of the block.
Returns:
The cached block if it exists, or None.
"""
if block_hash in self.cached_block_hash_to_block:
first_block_id = list(
self.cached_block_hash_to_block[block_hash].keys())[0]
return self.cached_block_hash_to_block[block_hash][first_block_id]
return None
def _touch(self, blocks: List[KVCacheBlock]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
Args:
blocks: A list of blocks to touch.
"""
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0:
self.free_block_queue.remove(block)
block.incr_ref()
def _cache_full_blocks(
self,
request: Request,
blk_start_idx: int,
full_blocks: List[KVCacheBlock],
prev_block: Optional[KVCacheBlock],
) -> None:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks whose block-hash metadata will be
updated and cached. Given a request, it computes the
block hashes for the blocks starting from `blk_start_idx` to the end
of the request's full blocks, updating the metadata for each block
and caching them in the `cached_block_hash_to_block`.
Args:
request: The request to cache the blocks.
blk_start_idx: The index of the first block in the request's blocks
to cache.
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
"""
block_hashes = self.req_to_block_hashes[request.request_id]
num_cached_block_hashes = len(block_hashes)
# Update the new blocks with the block hashes through the chain.
prev_block_hash_value = None
if prev_block is not None:
# Previous block must have a block hash because it must be
# a full, cached block.
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.hash_value
# Find the first uncached block. This case should only happen when
# speculative decoding is used.
offset = 0
for blk in full_blocks:
if blk.block_hash is None:
break
else:
prev_block_hash_value = blk.block_hash.hash_value
offset += 1
else:
# All blocks are cached.
return
for i, blk in enumerate(full_blocks[offset:]):
blk_idx = blk_start_idx + offset + i
assert blk.block_hash is None
if blk_idx < num_cached_block_hashes:
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = block_hashes[blk_idx]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
start_token_idx = blk_idx * self.block_size
end_token_idx = (blk_idx + 1) * self.block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Generate extra keys for multi-modal inputs. Note that since
# we reach this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash_value,
block_tokens, extra_keys)
block_hashes.append(block_hash)
# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash_value = block_hash.hash_value
def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request.