220 lines
8.1 KiB
Python
220 lines
8.1 KiB
Python
"""Tests for KVCacheManager with prefix caching enabled."""
|
|
from vllm.inputs import token_inputs
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
|
from vllm.v1.core.kv_cache_utils import hash_block_tokens
|
|
|
|
|
|
def make_request(request_id, prompt_token_ids):
    """Create a ``Request`` for the given prompt with fixed test settings.

    Args:
        request_id: Identifier passed straight through to ``Request``.
        prompt_token_ids: Prompt token IDs, wrapped with ``token_inputs``.

    Returns:
        A ``Request`` built with ``max_tokens=17``, ``eos_token_id=100``,
        ``arrival_time=0`` and no LoRA request.
    """
    request_kwargs = dict(
        request_id=request_id,
        inputs=token_inputs(prompt_token_ids=prompt_token_ids),
        sampling_params=SamplingParams(max_tokens=17),
        eos_token_id=100,
        arrival_time=0,
        lora_request=None,
    )
    return Request(**request_kwargs)
|
|
|
|
|
|
def test_prefill():
    """Check prefill allocation, prefix-cache hits, and eviction order.

    Scenario, using a 10-block manager with 16-token blocks:
      1. req0 misses the cache entirely and gets blocks 0-4.
      2. req1 shares the 48-token prefix while req0 is still alive, so
         blocks 0-2 are reused and their ref counts become 2.
      3. After both requests are freed, req2 hits the same prefix again.
      4. req3 shares nothing and consumes every free block, which pins
         down the eviction order.
    """
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=16,
    )

    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(16)]

    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    req0 = make_request("0", common_token_ids + unique_token_ids)
    computed_blocks = manager.get_computed_blocks(req0)
    assert not computed_blocks
    blocks = manager.allocate_slots(req0, 55, computed_blocks)
    assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]

    # Check full block metadata
    parent_block_hash = None
    for block_id in (0, 1, 2):
        block_hash = hash_block_tokens(parent_block_hash,
                                       manager.block_pool[block_id].token_ids)
        assert manager.block_pool[block_id].block_hash == block_hash
        assert manager.block_pool[block_id].ref_cnt == 1
        assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
            block_id + 1)
        assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16)
        parent_block_hash = block_hash

    # Check partial/preallocated block metadata
    for block_id in (3, 4):
        assert manager.block_pool[block_id].block_hash is None
        assert manager.block_pool[block_id].ref_cnt == 1
        assert manager.block_pool[block_id].num_hashed_tokens == 0
        if block_id == 3:
            assert manager.block_pool[block_id].token_ids == [3] * 7
        else:
            assert not manager.block_pool[block_id].token_ids

    # Cache hit in the common prefix when the original block is still in use.
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [3] * 5
    req1 = make_request("1", common_token_ids + unique_token_ids)
    computed_blocks = manager.get_computed_blocks(req1)
    assert [b.block_id for b in computed_blocks] == [0, 1, 2]
    num_new_tokens = 53 - 3 * 16
    blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
    assert [b.block_id for b in blocks] == [5, 6]
    for block in computed_blocks:
        assert block.ref_cnt == 2

    # At this point, we should have 3 free blocks left.
    assert manager.free_block_queue.num_free_blocks == 3

    manager.free(req0)
    manager.free(req1)

    # All blocks should be available.
    assert manager.free_block_queue.num_free_blocks == 10
    # The order should be
    # [unallocated (7, 8, 9)]
    # [unique_req0 (4, 3)]
    # [unique_req1 (6, 5)]
    # [common (2, 1, 0)]
    assert [
        b.block_id for b in manager.free_block_queue.get_all_free_blocks()
    ] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]

    # Cache hit in the common prefix when the original block is already free.
    # Incomplete 1 block (6 tokens)
    unique_token_ids = [3] * 6
    req2 = make_request("2", common_token_ids + unique_token_ids)
    # Use the lookup result for req2 itself rather than the list left over
    # from req1.
    computed_blocks = manager.get_computed_blocks(req2)
    assert [b.block_id for b in computed_blocks] == [0, 1, 2]
    # NOTE(review): req2 has 54 prompt tokens, so 54 - 3 * 16 would match
    # it exactly; 53 is kept from the req1 case -- confirm intent.
    num_new_tokens = 53 - 3 * 16
    blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
    assert [b.block_id for b in blocks] == [7, 8]

    # Blocks 0-2 are reused and 7-8 are newly allocated, so 5 of the 10
    # blocks remain free, and every block still in the queue is unreferenced.
    assert manager.free_block_queue.num_free_blocks == 5
    assert all(
        b.ref_cnt == 0
        for b in manager.free_block_queue.get_all_free_blocks())
    assert len(list(manager.free_block_queue.get_all_free_blocks())) == 5

    manager.free(req2)

    # Cache miss and eviction.
    req3 = make_request("3", [99] * (16 * 9))
    computed_blocks = manager.get_computed_blocks(req3)
    assert not computed_blocks
    # Allocate for req3 (the request that was just looked up), not the
    # already-freed req2.
    blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
    # This block ID order also checks the eviction order.
    assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
    assert manager.free_block_queue.num_free_blocks == 0
    assert manager.free_block_queue.free_list_head is None
    assert manager.free_block_queue.free_list_tail is None
|
|
|
|
|
|
def test_decode():
    """Check that decode-time appends fill existing blocks before new ones.

    After a 55-token prefill (blocks 0-4), output tokens are appended in
    three rounds: one that stays inside the partial block, one that spills
    into the preallocated block, and one that forces a new allocation.
    """
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=16,
    )

    def extend_output(request, token_id, count):
        """Append ``count`` copies of ``token_id`` to the request output."""
        for _ in range(count):
            request.append_output_token_ids(token_id)

    # Prompt: three complete blocks (48 tokens) plus a 7-token partial block.
    shared_prefix = [i for i in range(3) for _ in range(16)]
    tail = [3] * 7
    req0 = make_request("0", shared_prefix + tail)

    # Nothing is cached yet, so this is a full miss.
    computed_blocks = manager.get_computed_blocks(req0)
    assert not computed_blocks
    prefill_blocks = manager.allocate_slots(req0, 55, computed_blocks)
    assert [b.block_id for b in prefill_blocks] == [0, 1, 2, 3, 4]

    # Round 1: 4 tokens fit into block 3 (7 + 4 = 11), so no block is added.
    req0.num_computed_tokens = 55
    extend_output(req0, 8, 4)
    appended = manager.append_slots(req0, 4)
    assert appended is not None
    assert len(appended) == 0
    assert len(manager.block_pool[3].token_ids) == 11

    # Round 2: 5 tokens complete block 3 and 10 more land in block 4, which
    # was preallocated, so still no block is added.
    req0.num_computed_tokens = 59
    extend_output(req0, 7, 5 + 10)
    appended = manager.append_slots(req0, 15)
    assert appended is not None
    assert len(appended) == 0
    assert len(manager.block_pool[3].token_ids) == 16
    assert len(manager.block_pool[4].token_ids) == 10

    # Round 3: 6 tokens complete block 4 and 11 more need a new block.
    req0.num_computed_tokens = 74
    extend_output(req0, 12, 6 + 11)
    appended = manager.append_slots(req0, 17)
    # Plus one preallocated block.
    assert appended is not None
    assert len(appended) == 2
    assert len(manager.block_pool[4].token_ids) == 16
    assert len(manager.block_pool[5].token_ids) == 11
    assert len(manager.block_pool[6].token_ids) == 0
|
|
|
|
|
|
def test_evict():
    """Check free-queue ordering and reuse of cached blocks after freeing.

    Two requests together occupy all 10 blocks. Once both are freed, a
    third request whose prompt matches the start of the first one should
    hit blocks 0 and 1 and take its new blocks from the head of the free
    queue.
    """
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=16,
    )

    # req0 covers token IDs [0, 87).
    req0_len = 5 * 16 + 7
    req0 = make_request("0", list(range(req0_len)))
    computed_blocks = manager.get_computed_blocks(req0)
    assert not computed_blocks
    req0_blocks = manager.allocate_slots(req0, req0_len, computed_blocks)
    assert len(req0_blocks) == 7  # 5 full + 1 partial + 1 preallocated

    # req1 continues the token ID range for another 3 blocks.
    req1_len = 3 * 16
    req1 = make_request("1", list(range(req0_len, req0_len + req1_len)))
    computed_blocks = manager.get_computed_blocks(req1)
    assert not computed_blocks
    req1_blocks = manager.allocate_slots(req1, req1_len, computed_blocks)
    assert len(req1_blocks) == 3  # 3 full blocks

    # Every block is now in use.
    assert manager.free_block_queue.num_free_blocks == 0

    manager.free(req0)
    manager.free(req1)
    assert manager.free_block_queue.num_free_blocks == 10
    free_ids = [
        b.block_id for b in manager.free_block_queue.get_all_free_blocks()
    ]
    assert free_ids == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]

    # Touch the first 2 blocks.
    req2 = make_request("2", list(range(2 * 16 + 3)))
    computed_blocks = manager.get_computed_blocks(req2)
    assert [b.block_id for b in computed_blocks] == [0, 1]
    req2_blocks = manager.allocate_slots(req2, 3, computed_blocks)
    assert [b.block_id for b in req2_blocks] == [6, 5]
    assert manager.free_block_queue.num_free_blocks == 6
|