[V1] Prefix caching (take 2) (#9972)

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Cody Yu 2024-11-07 17:34:44 -08:00, committed by GitHub
parent 42b4f46b71
commit 201fc07730
6 changed files with 770 additions and 65 deletions

View File

@@ -118,7 +118,7 @@ def main(args):
random.seed(args.seed)
if args.dataset_path is not None:
print(f"Start to sample {args.num_prompts} prompts"
- "from {args.dataset_path}")
f"from {args.dataset_path}")
filtered_datasets = sample_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
@@ -142,13 +142,6 @@ def main(args):
repeat_count=args.repeat_count,
sort=args.sort)
- print("------warm up------")
- test_prefix(
- llm=llm,
- prompts=prompts,
- sampling_params=sampling_params,
- )
print("------start generating------")
test_prefix(
llm=llm,
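The first hunk above adds the missing f prefix on the second string literal; without it, Python's implicit string concatenation keeps the placeholder as literal text. A tiny sketch with made-up values (note that the missing space between "prompts" and "from" is untouched by this commit):

num_prompts, dataset_path = 10, "sample.jsonl"  # hypothetical values for illustration
print(f"Start to sample {num_prompts} prompts" "from {dataset_path}")
# -> Start to sample 10 promptsfrom {dataset_path}
print(f"Start to sample {num_prompts} prompts" f"from {dataset_path}")
# -> Start to sample 10 promptsfrom sample.jsonl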

View File

@ -0,0 +1,219 @@
"""Compare the with and without prefix caching."""
from vllm.inputs import DecoderOnlyInputs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import hash_block_tokens
def make_request(request_id, prompt_token_ids):
return Request(
request_id=request_id,
inputs=DecoderOnlyInputs(prompt_token_ids=prompt_token_ids),
sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100,
arrival_time=0,
lora_request=None,
)
def test_prefill():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
sliding_window=False,
enable_caching=True,
num_preallocate_tokens=16,
)
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_hash = hash_block_tokens(parent_block_hash,
manager.block_pool[block_id].token_ids)
assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
block_id + 1)
assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16)
parent_block_hash = block_hash
# Check partial/preallocated block metadata
for block_id in (3, 4):
assert manager.block_pool[block_id].block_hash is None
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 0
if block_id == 3:
assert manager.block_pool[block_id].token_ids == [3] * 7
else:
assert not manager.block_pool[block_id].token_ids
# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req1)
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
for block in computed_blocks:
assert block.ref_cnt == 2
# At this point, we should have 3 free blocks left.
assert manager.free_block_queue.num_free_blocks == 3
manager.free(req0)
manager.free(req1)
# All blocks should be available.
assert manager.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
assert [
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8]
# Touching the 3 computed blocks removed them from the free queue, and
# blocks 7 and 8 were newly allocated, so 5 free blocks remain.
assert manager.free_block_queue.num_free_blocks == 5
assert all([
b.ref_cnt == 0 for b in manager.free_block_queue.get_all_free_blocks()
])
assert len([b
for b in manager.free_block_queue.get_all_free_blocks()]) == 5
manager.free(req2)
# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 9))
computed_blocks = manager.get_computed_blocks(req3)
assert not computed_blocks
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert manager.free_block_queue.num_free_blocks == 0
assert manager.free_block_queue.free_list_head is None
assert manager.free_block_queue.free_list_tail is None
def test_decode():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
sliding_window=False,
enable_caching=True,
num_preallocate_tokens=16,
)
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
for _ in range(4):
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 11
# Append slots without allocating a new block, but start using the
# preallocated block.
req0.num_computed_tokens = 59
# 5 tokens to fill the previous block, and 10 tokens to fill
# the preallocated block.
for _ in range(5 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.append_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 16
assert len(manager.block_pool[4].token_ids) == 10
# Append slots with allocating a new block.
req0.num_computed_tokens = 74
# 6 tokens to fill the previous block, and 11 tokens to fill
# a newly allocated block.
for _ in range(6 + 11):
req0.append_output_token_ids(12)
new_blocks = manager.append_slots(req0, 17)
# Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2
assert len(manager.block_pool[4].token_ids) == 16
assert len(manager.block_pool[5].token_ids) == 11
assert len(manager.block_pool[6].token_ids) == 0
def test_evict():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
sliding_window=False,
enable_caching=True,
num_preallocate_tokens=16,
)
last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)))
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 7 # 5 full + 1 partial + 1 preallocated
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)))
computed_blocks = manager.get_computed_blocks(req1)
assert not computed_blocks
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
assert len(blocks) == 3 # 3 full blocks
last_token_id += 3 * 16
assert manager.free_block_queue.num_free_blocks == 0
manager.free(req0)
manager.free(req1)
assert manager.free_block_queue.num_free_blocks == 10
assert [
b.block_id for b in manager.free_block_queue.get_all_free_blocks()
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1]
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6

View File

@@ -1,9 +1,11 @@
from collections import defaultdict
from typing import Dict, List, Optional
- import numpy as np
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock, hash_block_tokens,
hash_request_tokens)
from vllm.v1.request import Request
logger = init_logger(__name__)
@@ -36,73 +38,359 @@ class KVCacheManager:
self.num_preallocate_tokens = num_preallocate_tokens
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
- self.free_block_ids = list(range(num_gpu_blocks))
- self.req_to_block_ids: Dict[str, List[int]] = {}
- self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32)
# A Block pool of all kv-cache blocks.
self.block_pool: List[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.block_pool)
- def get_computed_blocks(self, request: Request) -> List[int]:
# {block_hash: {block ID: block}}. A cached block is
# a full block with a block hash that can be used for prefix caching.
# The cached block may be used by running requests or in the
# free_block_queue that could potentially be evicted.
# NOTE: We currently don't de-duplicate the blocks in the cache,
# meaning that if a block becomes full and is cached, we don't check
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self.cached_block_hash_to_block: Dict[BlockHashType, Dict[
int, KVCacheBlock]] = defaultdict(dict)
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
self.req_to_blocks: Dict[str, List[KVCacheBlock]] = {}
def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
Args:
request: The request to get the computed blocks.
Returns:
A list of blocks that are computed for the request.
"""
if not self.enable_caching:
- # No prefix caching.
# Prefix caching is disabled.
return []
- # TODO(woosuk): Implement hash-based caching.
- return []
computed_blocks = []
block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids)
for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := self._get_cached_block(block_hash):
computed_blocks.append(cached_block)
else:
break
return computed_blocks
def append_slots(
self,
request: Request,
num_tokens: int,
- ) -> Optional[List[int]]:
) -> Optional[List[KVCacheBlock]]:
"""Append slots to the block table of the request.
We first append slots to already allocated blocks. If the allocated
blocks are not enough, we allocate new blocks.
Args:
request: The request to append slots.
num_tokens: The number of tokens to append.
Returns:
A list of new blocks if new blocks are allocated, or None
if new blocks are required but cannot be allocated.
"""
num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
self.block_size)
- req_block_ids = self.req_to_block_ids[request.request_id]
- if num_required_blocks <= len(req_block_ids):
- # No new block is needed.
- return []
- num_new_blocks = num_required_blocks - len(req_block_ids)
- num_free_blocks = len(self.free_block_ids)
- if num_new_blocks > num_free_blocks:
- # Cannot allocate new blocks.
- return None
- # Allocate new blocks.
req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = num_required_blocks - len(req_blocks)
if num_new_blocks > self.free_block_queue.num_free_blocks:
# Need to allocate new blocks due to insufficient pre-allocated
# slots, but we cannot allocate new blocks due to the limit.
return None
# When caching is enabled, assign token IDs to already allocated blocks.
new_token_ids = None
parent_block = None
if self.enable_caching:
# Figure out the token IDs to add to the blocks.
new_token_ids = request.all_token_ids[
request.num_computed_tokens:request.num_computed_tokens +
num_tokens]
# Find the last full block index.
# TODO: This may be optimized by calculating the computed tokens.
last_full_block_idx = len(req_blocks) - 1
while (last_full_block_idx >= 0
and req_blocks[last_full_block_idx].block_hash is None):
last_full_block_idx -= 1
parent_block = (req_blocks[last_full_block_idx]
if last_full_block_idx >= 0 else None)
token_id_idx = self._add_token_ids_to_blocks(
blocks=req_blocks[last_full_block_idx + 1:],
token_ids=new_token_ids,
parent_block=parent_block)
new_token_ids = new_token_ids[token_id_idx:]
parent_block = req_blocks[-1]
# No new block is needed. When caching is enabled, we make sure
# token_id_idx is equal to len(new_token_ids), meaning that all tokens
# are added to allocated blocks.
if num_required_blocks <= len(req_blocks):
assert not self.enable_caching or token_id_idx == num_tokens, \
f"{token_id_idx=} != {num_tokens=}"
return []
# Allocate new blocks considering preallocated blocks, and
# add token IDs to them if caching is enabled.
num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks,
- num_free_blocks)
- new_block_ids = self._get_new_blocks(num_new_blocks)
- req_block_ids.extend(new_block_ids)
- self.ref_cnts[new_block_ids] += 1
- return new_block_ids
self.free_block_queue.num_free_blocks)
new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
parent_block)
req_blocks.extend(new_blocks)
return new_blocks
def allocate_slots(
self,
request: Request,
num_tokens: int,
- computed_block_ids: List[int],
- ) -> Optional[List[int]]:
computed_blocks: List[KVCacheBlock],
) -> Optional[List[KVCacheBlock]]:
"""Allocate slots for a new request.
Args:
request: The request to allocate slots.
num_tokens: The number of tokens to allocate. Note that this does
not include the tokens that have already been computed.
computed_blocks: The blocks that have already been computed.
Returns:
A list of new allocated blocks.
"""
if num_tokens == 0:
raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}")
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = len(
[blk for blk in computed_blocks if blk.ref_cnt == 0])
num_required_blocks = cdiv(num_tokens, self.block_size)
- num_free_blocks = len(self.free_block_ids)
- if num_required_blocks > num_free_blocks:
- # Cannot allocate new blocks.
- return None
- num_new_blocks = min(num_required_blocks + self.num_preallocate_blocks,
- num_free_blocks)
- new_block_ids = self._get_new_blocks(num_new_blocks)
- block_ids = computed_block_ids + new_block_ids
- self.req_to_block_ids[request.request_id] = block_ids
- self.ref_cnts[block_ids] += 1
- return new_block_ids
if (num_required_blocks > self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks):
# Cannot allocate new blocks.
return None
# Determine the number of new blocks to allocate considering
# preallocated blocks.
num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks)
num_computed_tokens = len(computed_blocks) * self.block_size
# When caching is enabled, get the new token IDs and the parent block
# ID to generate cache keys.
new_token_ids = None
parent_block = None
if self.enable_caching:
# Touch the computed blocks to make sure they won't be evicted.
self._touch(computed_blocks)
# Get the token IDs for the blocks being allocated for hashing.
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_tokens]
if not new_token_ids:
raise RuntimeError(
"Failed to infer the token IDs for allocation. "
f"#all_tokens={len(request.all_token_ids)} < "
f"#computed_tokens={num_computed_tokens}")
# Get the parent block ID to construct the block chain.
parent_block = computed_blocks[-1] if computed_blocks else None
new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
parent_block)
# Concatenate the computed block IDs and the new block IDs.
self.req_to_blocks[request.request_id] = computed_blocks + new_blocks
return new_blocks
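A worked example of the accounting above, taken from test_prefill: when req2 arrives, all 10 blocks sit on the free queue and its 3 computed prefix blocks are eviction candidates (ref_cnt == 0), so num_evictable_computed_blocks is 3. With 5 new tokens and block_size 16, num_required_blocks is 1; the capacity check compares 1 against 10 - 3 = 7, and num_new_blocks = min(1 + 1, 7) = 2, which is why the test sees exactly blocks [7, 8] being allocated.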
def free(self, request: Request) -> None:
- block_ids = self.req_to_block_ids.pop(request.request_id)
- self.ref_cnts[block_ids] -= 1
- for block_id in block_ids:
- ref_cnt = self.ref_cnts[block_id]
- if ref_cnt == 0:
- self.free_block_ids.append(block_id)
-
- def _get_new_blocks(self, num_blocks: int) -> List[int]:
- assert num_blocks <= len(self.free_block_ids)
- new_block_ids = self.free_block_ids[-num_blocks:]
- self.free_block_ids = self.free_block_ids[:-num_blocks]
- return new_block_ids
"""Free the blocks allocated for the request.
When caching is enabled, we free the blocks in reverse order so that
the tail blocks are evicted first.
Args:
request: The request to free the blocks.
"""
blocks = self.req_to_blocks.pop(request.request_id)
if self.enable_caching:
# Free blocks in reverse order so that the tail blocks are
# freed first.
blocks = reversed(blocks)
for block in blocks:
block.ref_cnt -= 1
if block.ref_cnt == 0:
self.free_block_queue.append(block)
def _get_new_blocks(
self,
num_blocks: int,
token_ids: Optional[List[int]] = None,
parent_block: Optional[KVCacheBlock] = None) -> List[KVCacheBlock]:
"""Get new blocks from the free block pool, and add token IDs to
allocated blocks if caching is enabled.
Note that we do not check block cache in this function.
Args:
num_blocks: The number of blocks to allocate.
token_ids: The token IDs in the blocks. None if caching is disabled.
parent_block: The parent block. Used to include block chain
in the block hash.
Returns:
A list of new blocks.
"""
if num_blocks > self.free_block_queue.num_free_blocks:
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
# First allocate blocks.
ret: List[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0
# Evict blocks from the cache.
if self.enable_caching:
block_hash = curr_block.block_hash
if (block_hash is not None
and block_hash in self.cached_block_hash_to_block):
if len(self.cached_block_hash_to_block[block_hash]) == 1:
del self.cached_block_hash_to_block[block_hash]
else:
del self.cached_block_hash_to_block[block_hash][
curr_block.block_id]
curr_block.reset()
curr_block.ref_cnt = 1
ret.append(curr_block)
idx += 1
# Then assign token IDs to the allocated blocks.
if self.enable_caching:
assert token_ids is not None
token_id_idx = self._add_token_ids_to_blocks(
blocks=ret, token_ids=token_ids, parent_block=parent_block)
assert token_id_idx == len(token_ids)
return ret
def _cache_full_block(self,
block: KVCacheBlock,
parent_block: Optional[KVCacheBlock] = None) -> None:
"""Cache a full block for prefix caching.
Args:
block: The block to cache.
parent_block: The parent block. None if this is the first block.
"""
parent_block_hash = (parent_block.block_hash
if parent_block is not None else None)
assert len(block.token_ids) == self.block_size
block.token_ids = tuple(block.token_ids)
block_hash = hash_block_tokens(parent_block_hash, block.token_ids)
block.block_hash = block_hash
block.num_hashed_tokens = self.block_size + (
parent_block.num_hashed_tokens if parent_block is not None else 0)
self.cached_block_hash_to_block[block_hash][block.block_id] = block
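For the three common-prefix blocks in the test file above (block_size 16), this chaining yields num_hashed_tokens of 16, 32 and 48, matching the `16 * (block_id + 1)` assertion in test_prefill.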
def _get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
If there are duplicated blocks, we return the first block in the cache.
Args:
block_hash: The hash value of the block.
Returns:
The cached block if it exists, or None.
"""
if block_hash in self.cached_block_hash_to_block:
first_block_id = list(
self.cached_block_hash_to_block[block_hash].keys())[0]
return self.cached_block_hash_to_block[block_hash][first_block_id]
return None
def _touch(self, blocks: List[KVCacheBlock]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
Args:
blocks: A list of blocks to touch.
"""
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0:
self.free_block_queue.remove(block)
block.ref_cnt += 1
def _add_token_ids_to_blocks(
self,
blocks: List[KVCacheBlock],
token_ids: List[int],
parent_block: Optional[KVCacheBlock] = None) -> int:
"""Add token IDs to a list of allocated blocks.
If a block becomes full after adding token IDs, cache it.
Return the token ID index that has not been added to the blocks
if the blocks are not enough to hold all the token IDs.
Args:
blocks: A list of blocks to add token IDs.
token_ids: A list of token IDs to add.
parent_block: The parent block. None if this is the
first block.
Returns:
The starting token ID index that has not been added to the blocks
due to insufficient given blocks.
"""
token_id_start = 0
for curr_block in blocks:
# If all token IDs are added, then the rest of the blocks are
# preallocated blocks, so we only need to update the
# parent_block_id. FIXME
if token_id_start == len(token_ids):
continue
# Add token IDs to the empty slots in the block.
empty_slots = self.block_size - len(curr_block.token_ids)
token_id_end = min(token_id_start + empty_slots, len(token_ids))
curr_block.token_ids.extend(token_ids[token_id_start:token_id_end])
# Cache the block if it becomes full.
if len(curr_block.token_ids) == self.block_size:
self._cache_full_block(curr_block, parent_block)
parent_block = curr_block
token_id_start = token_id_end
return token_id_start
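Taken together, a request's life cycle against the manager looks roughly like the sketch below. It is a condensed version of test_prefill/test_decode above (the constructor arguments and the make_request helper come from that test file); it is illustrative only, not part of this commit:

manager = KVCacheManager(block_size=16,
                         num_gpu_blocks=10,
                         sliding_window=False,
                         enable_caching=True,
                         num_preallocate_tokens=16)
req = make_request("demo", [0] * 16 + [1] * 16 + [7] * 3)  # 2 full blocks + 3 extra tokens
computed = manager.get_computed_blocks(req)                # [] on a cold cache
blocks = manager.allocate_slots(req, 35, computed)         # prefill: 3 blocks + 1 preallocated
req.num_computed_tokens = 35
req.append_output_token_ids(42)                            # one decoded token
manager.append_slots(req, 1)                               # decode: fits in the partial block
manager.free(req)                                          # blocks return to the queue, tail first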

View File

@ -0,0 +1,194 @@
"""KV-Cache Utilities."""
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union
from vllm.logger import init_logger
logger = init_logger(__name__)
BlockHashType = Tuple[int, Tuple[int]]
@dataclass
class KVCacheBlock:
"""KV-cache block metadata."""
# Block ID, ranging from 0 to num_gpu_blocks - 1.
block_id: int
# Reference count.
ref_cnt: int = 0
# Token IDs in the block. When the block is full, the type of token_ids
# should be Tuple[int] for fast matching.
token_ids: Union[List[int], Tuple[int]] = field(default_factory=list)
# The hash of the block composed of (block hash, tuple of token IDs).
# It is only available when the block is full.
block_hash: Optional[BlockHashType] = None
# The number of hashed tokens. More hashed tokens means the block
# is closer to the end of a prompt and more likely to be evicted.
num_hashed_tokens: int = 0
# Used to construct a doubly linked list for free blocks.
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
prev_free_block: Optional["KVCacheBlock"] = None
next_free_block: Optional["KVCacheBlock"] = None
def reset(self):
"""Reset the block metadata."""
self.ref_cnt = 0
self.token_ids = []
self.block_hash = None
self.num_hashed_tokens = 0
class FreeKVCacheBlockQueue:
"""This class organizes a list of KVCacheBlock objects to a doubly linked
list of free blocks. We implement this class instead of using Python
builtin deque to support removing a block in the middle of the queue
in O(1) time. To close the performance gap to the builtin deque which is
implemented in C++, this class does not allocate any Python objects when
manipulating the linked list. Instead, this class manipulates the
prev_free_block and next_free_block attributes of the given blocks.
The queue is ordered by block ID in the beginning. When a block is allocated
and then freed, it will be appended back with the eviction order:
1. The least recent used block is at the front (LRU).
2. If two blocks have the same last accessed time (allocated by the
same sequence), the one with more hash tokens (the tail of a block
chain) is at the front.
Note that we maintain this order by reversing the block order when free
blocks of a request. This operation is outside of this class.
Args:
blocks: A list of KVCacheBlock objects.
"""
def __init__(self, blocks: List[KVCacheBlock]) -> None:
self.num_free_blocks = len(blocks)
# Initialize the doubly linked list of free blocks.
self.free_list_head = blocks[0]
self.free_list_tail = blocks[-1]
for i in range(self.num_free_blocks):
if i > 0:
blocks[i].prev_free_block = blocks[i - 1]
if i < self.num_free_blocks - 1:
blocks[i].next_free_block = blocks[i + 1]
def popleft(self) -> KVCacheBlock:
"""Pop the first free block and reduce num_free_blocks by 1.
Returns:
The first free block.
"""
if not self.free_list_head:
raise ValueError("No free blocks available")
block = self.free_list_head
self.remove(block)
return block
def remove(self, block: KVCacheBlock) -> None:
"""Remove a block in the free list and reduce num_free_blocks by 1.
Args:
block: The block to remove.
"""
if block.prev_free_block is not None:
# Link the previous block to the next block.
block.prev_free_block.next_free_block = block.next_free_block
if block.next_free_block is not None:
# Link the next block to the previous block.
block.next_free_block.prev_free_block = block.prev_free_block
if block == self.free_list_head:
# Update the head if the block is the head.
self.free_list_head = block.next_free_block
if block == self.free_list_tail:
# Update the tail if the block is the tail.
self.free_list_tail = block.prev_free_block
# Remove the block from the linked list.
block.prev_free_block = block.next_free_block = None
self.num_free_blocks -= 1
def append(self, block: KVCacheBlock) -> None:
"""Put a block back into the free list and increase
num_free_blocks by 1.
Args:
block: The block to append.
"""
if self.free_list_tail is not None:
# Link the last block to the new block.
self.free_list_tail.next_free_block = block
block.prev_free_block = self.free_list_tail
self.free_list_tail = block
else:
# The free list is empty.
assert self.free_list_head is None
self.free_list_head = self.free_list_tail = block
block.next_free_block = None
self.num_free_blocks += 1
def get_all_free_blocks(self) -> List[KVCacheBlock]:
"""Get all free blocks in the free list. Mainly used for testing.
Returns:
A list of free blocks.
"""
ret = []
curr_block = self.free_list_head
while curr_block is not None:
ret.append(curr_block)
curr_block = curr_block.next_free_block
return ret
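A minimal usage sketch of the queue defined above (illustrative, not part of this commit), showing the O(1) removal from the middle of the list:

blocks = [KVCacheBlock(i) for i in range(4)]
queue = FreeKVCacheBlockQueue(blocks)
queue.remove(blocks[2])      # O(1) unlink from the middle of the list
first = queue.popleft()      # block 0, the current eviction candidate
queue.append(blocks[2])      # block 2 rejoins at the tail
assert [b.block_id for b in queue.get_all_free_blocks()] == [1, 3, 2]
assert queue.num_free_blocks == 3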
def hash_block_tokens(parent_block_hash: Optional[int],
curr_block_token_ids: Tuple[int]) -> BlockHashType:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
hash values for the same block contents.
TODO: Support arbitrary metadata so that we could support more
features such as LoRA adapter.
Args:
parent_block_hash: The hash of the parent block. None
if this is the first block.
curr_block_token_ids: A tuple of token ids in the current
block. The current block is assumed to be full.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return (hash(
(parent_block_hash, *curr_block_token_ids)), curr_block_token_ids)
def hash_request_tokens(block_size: int,
token_ids: List[int]) -> List[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Args:
block_size: The size of each block.
token_ids: A sequence of token ids in the request.
Returns:
The list of computed hash values.
"""
ret = []
parent_block_hash = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = tuple(token_ids[start:end])
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
block_hash = hash_block_tokens(parent_block_hash, block_token_ids)
ret.append(block_hash)
parent_block_hash = block_hash
return ret
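As a quick sanity check of the chaining (illustrative only): hashing a request through hash_request_tokens reproduces what you get by calling hash_block_tokens block by block, passing each block's full hash value as the parent of the next; partial blocks are skipped:

token_ids = [0] * 16 + [1] * 16 + [2] * 7   # two full blocks plus a partial block
chain = hash_request_tokens(16, token_ids)
assert len(chain) == 2                      # the 7-token tail is not hashed
first = hash_block_tokens(None, tuple([0] * 16))
second = hash_block_tokens(first, tuple([1] * 16))
assert chain == [first, second]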

View File

@@ -34,7 +34,7 @@ class Scheduler:
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
sliding_window=self.cache_config.sliding_window,
- enable_caching=True)
enable_caching=self.cache_config.enable_prefix_caching)
self.block_size = self.cache_config.block_size
# Scheduling constraints.
@@ -91,9 +91,9 @@ class Scheduler:
assert num_new_tokens > 0
while True:
- new_block_ids = self.kv_cache_manager.append_slots(
new_blocks = self.kv_cache_manager.append_slots(
request, num_new_tokens)
- if new_block_ids is None:
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
preempted_req = self.running.pop()
@@ -110,7 +110,9 @@ class Scheduler:
# The request can be scheduled.
scheduled_running_reqs.append(request)
- req_to_new_block_ids[request.request_id] = new_block_ids
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
@@ -126,22 +128,29 @@ class Scheduler:
request = self.waiting[0]
# Get already-cached tokens.
- computed_block_ids = self.kv_cache_manager.get_computed_blocks(
computed_blocks = self.kv_cache_manager.get_computed_blocks(
request)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
- num_computed_tokens = len(computed_block_ids) * self.block_size
num_computed_tokens = len(computed_blocks) * self.block_size
# Number of tokens to be scheduled.
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if num_new_tokens == 0:
# This happens when the prompt length is divisible by the block
# size and all of its blocks are cached. For now, we force the
# scheduler to recompute the last token.
num_computed_tokens -= 1
num_new_tokens = 1
computed_blocks.pop()
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
- new_block_ids = self.kv_cache_manager.allocate_slots(
- request, num_new_tokens, computed_block_ids)
- if new_block_ids is None:
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_blocks)
if new_blocks is None:
# The request cannot be scheduled.
break
request.num_computed_tokens = num_computed_tokens
@@ -156,8 +165,9 @@ class Scheduler:
raise RuntimeError(
f"Invalid request status: {request.status}")
- req_to_new_block_ids[request.request_id] = (
- computed_block_ids + new_block_ids)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in computed_blocks + new_blocks
]
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
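A concrete instance of the new num_new_tokens == 0 branch above: with block_size 16 and a 32-token prompt whose two blocks are both cache hits, num_computed_tokens starts at 32 and num_new_tokens at 0; the scheduler then pops the last computed block and schedules a single token (num_computed_tokens becomes 31, num_new_tokens becomes 1), so the request still runs the model once and can produce its first output token.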

View File

@@ -65,6 +65,7 @@ class LLMEngine:
elif usage_context == UsageContext.OPENAI_API_SERVER:
scheduler_config.max_num_seqs = 1024
scheduler_config.max_num_batched_tokens = 2048
cache_config.enable_prefix_caching = True
logger.info(
"Initializing an LLM engine (v%s) with config: "