[V1] Implement sliding window attention in kv_cache_manager (#14097)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Parent: c7e63aa4d8
Commit: 3a5f0afcd2
@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
|
||||
check_answers(indices, answer, test_texts)
|
||||
|
||||
|
||||
def prep_prompts(batch_size: int):
|
||||
def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
|
||||
"""
|
||||
Generate prompts with a bunch of assignments,
|
||||
then asking for the value of one of them.
|
||||
The prompt is just under 10k tokens; sliding window is 4k
|
||||
so the answer is outside sliding window, but should still be correct.
|
||||
|
||||
Args:
|
||||
batch_size: number of prompts to generate
|
||||
ln_range: an argument to control the length of the prompt
|
||||
"""
|
||||
prompts: list[str] = []
|
||||
answer: list[int] = []
|
||||
@ -145,7 +149,7 @@ def prep_prompts(batch_size: int):
|
||||
indices.append(idx)
|
||||
prompt = "```python\n# We set a number of variables, " + \
|
||||
f"x{idx} will be important later\n"
|
||||
ln = random.randint(800, 1100)
|
||||
ln = random.randint(*ln_range)
|
||||
for k in range(30, ln):
|
||||
v = random.randint(10, 99)
|
||||
if k == idx:
|
||||
@ -157,7 +161,10 @@ def prep_prompts(batch_size: int):
|
||||
return prompts, answer, indices
|
||||
|
||||
|
||||
def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
|
||||
def check_answers(indices: list[int],
|
||||
answer: list[int],
|
||||
outputs: list[str],
|
||||
accept_rate: float = 0.7):
|
||||
answer2 = [int(text[0:2].strip()) for text in outputs]
|
||||
print(list(zip(indices, zip(answer, answer2))))
|
||||
numok = 0
|
||||
@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
|
||||
numok += 1
|
||||
frac_ok = numok / len(answer)
|
||||
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
|
||||
assert frac_ok > 0.7
|
||||
assert frac_ok >= accept_rate
|
||||
|
||||
|
||||
def check_window(prompts: list[str]):
|
||||
|
@ -4,6 +4,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@ -12,6 +13,8 @@ from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
||||
hash_block_tokens)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
|
||||
|
||||
def make_request(request_id,
|
||||
@ -39,13 +42,23 @@ def make_request(request_id,
|
||||
)
|
||||
|
||||
|
||||
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
return KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
tensors={},
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32,
|
||||
False))
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
|
||||
def test_prefill(hash_algo):
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
make_kv_cache_config(16, 11),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
caching_hash_algo=hash_algo,
|
||||
num_preallocate_tokens=16,
|
||||
@ -67,12 +80,12 @@ def test_prefill(hash_algo):
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
|
||||
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
for block_id in (0, 1, 2):
|
||||
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
|
||||
for block_id in (1, 2, 3):
|
||||
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
@ -80,7 +93,7 @@ def test_prefill(hash_algo):
|
||||
parent_block_hash = block_hash.hash_value
|
||||
|
||||
# Check partial/preallocated block metadata
|
||||
for block_id in (3, 4):
|
||||
for block_id in (4, 5):
|
||||
assert manager.block_pool.blocks[block_id].block_hash is None
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
|
||||
@ -90,11 +103,11 @@ def test_prefill(hash_algo):
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
|
||||
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [5, 6]
|
||||
assert [b.block_id for b in blocks] == [6, 7]
|
||||
for block in computed_blocks:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@ -107,14 +120,14 @@ def test_prefill(hash_algo):
|
||||
# All blocks should be available.
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 10
|
||||
# The order should be
|
||||
# [unallocated (7, 8, 9)]
|
||||
# [unique_req0 (4, 3)]
|
||||
# [unique_req1 (6, 5)]
|
||||
# [common (2, 1, 0)]
|
||||
# [unallocated (8, 9, 10)]
|
||||
# [unique_req0 (5, 4)]
|
||||
# [unique_req1 (7, 6)]
|
||||
# [common (3, 2, 1)]
|
||||
assert [
|
||||
b.block_id
|
||||
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
|
||||
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
|
||||
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
|
||||
|
||||
# Cache hit in the common prefix when the original block is already free.
|
||||
# Incomplete 1 block (6 tokens)
|
||||
@ -122,11 +135,11 @@ def test_prefill(hash_algo):
|
||||
req2 = make_request("2", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
|
||||
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
|
||||
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [7, 8]
|
||||
assert [b.block_id for b in blocks] == [8, 9]
|
||||
|
||||
# Although we only have 5 free blocks, we have 8 blocks in
|
||||
# the free block queue due to lazy removal.
|
||||
@ -148,7 +161,7 @@ def test_prefill(hash_algo):
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
|
||||
# This block ID order also checks the eviction order.
|
||||
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
|
||||
assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1]
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 0
|
||||
assert manager.block_pool.free_block_queue.free_list_head is None
|
||||
assert manager.block_pool.free_block_queue.free_list_tail is None
|
||||
@ -162,10 +175,8 @@ def test_prefill_plp():
|
||||
3. Schedule plp request; no hit should occur; validate blocks
|
||||
'''
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
make_kv_cache_config(16, 11),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=16,
|
||||
)
|
||||
@ -186,13 +197,13 @@ def test_prefill_plp():
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
|
||||
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
|
||||
req0_block_hashes = [b.block_hash for b in blocks]
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
for block_id in (0, 1, 2):
|
||||
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
|
||||
for block_id in (1, 2, 3):
|
||||
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
@ -200,7 +211,7 @@ def test_prefill_plp():
|
||||
parent_block_hash = block_hash.hash_value
|
||||
|
||||
# Check partial/preallocated block metadata
|
||||
for block_id in (3, 4):
|
||||
for block_id in (4, 5):
|
||||
assert manager.block_pool.blocks[block_id].block_hash is None
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
|
||||
@ -211,11 +222,11 @@ def test_prefill_plp():
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
|
||||
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [5, 6]
|
||||
assert [b.block_id for b in blocks] == [6, 7]
|
||||
for block in computed_blocks:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@ -228,14 +239,14 @@ def test_prefill_plp():
|
||||
# All blocks should be available.
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 10
|
||||
# The order should be
|
||||
# [unallocated (7, 8, 9)]
|
||||
# [unique_req0 (4, 3)]
|
||||
# [unique_req1 (6, 5)]
|
||||
# [common (2, 1, 0)]
|
||||
# [unallocated (8, 9, 10)]
|
||||
# [unique_req0 (5, 4)]
|
||||
# [unique_req1 (7, 6)]
|
||||
# [common (3, 2, 1)]
|
||||
assert [
|
||||
b.block_id
|
||||
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
|
||||
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
|
||||
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
|
||||
|
||||
# Request #2 is a prompt-logprobs request:
|
||||
# NO cache hit in the common prefix; duplicates request #0 cached blocks
|
||||
@ -251,7 +262,7 @@ def test_prefill_plp():
|
||||
block_ids = [b.block_id for b in blocks]
|
||||
# Duplicate cached blocks have different ids but same hashes vs request #0
|
||||
assert [b.block_hash for b in blocks] == req0_block_hashes
|
||||
assert block_ids != [0, 1, 2, 3, 4]
|
||||
assert block_ids != [1, 2, 3, 4, 5]
|
||||
|
||||
# Request #2 block hashes are valid since request #0 hashes are.
|
||||
# Check block reference counts.
|
||||
@ -263,10 +274,8 @@ def test_prefill_plp():
|
||||
|
||||
def test_decode():
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
make_kv_cache_config(16, 11),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=16,
|
||||
)
|
||||
@ -282,7 +291,7 @@ def test_decode():
|
||||
assert not computed_blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
|
||||
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
req0.num_computed_tokens = 55
|
||||
@ -316,10 +325,8 @@ def test_decode():
|
||||
|
||||
def test_evict():
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
make_kv_cache_config(16, 11),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=16,
|
||||
)
|
||||
@ -350,15 +357,15 @@ def test_evict():
|
||||
assert [
|
||||
b.block_id
|
||||
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
|
||||
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]
|
||||
] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8]
|
||||
|
||||
# Touch the first 2 blocks.
|
||||
req2 = make_request("2", list(range(2 * 16 + 3)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert [b.block_id for b in computed_blocks] == [0, 1]
|
||||
assert [b.block_id for b in computed_blocks] == [1, 2]
|
||||
assert num_computed_tokens == 2 * 16
|
||||
blocks = manager.allocate_slots(req2, 3, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [6, 5]
|
||||
assert [b.block_id for b in blocks] == [7, 6]
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 6
|
||||
|
||||
|
||||
@ -369,10 +376,8 @@ def test_hash_block_correct_reuse():
|
||||
"""
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=1,
|
||||
make_kv_cache_config(16, 2),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=0,
|
||||
)
|
||||
@ -408,10 +413,8 @@ def test_computed_blocks_not_evicted():
|
||||
"""
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=2,
|
||||
make_kv_cache_config(block_size, 3),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=0,
|
||||
)
|
||||
@ -424,7 +427,7 @@ def test_computed_blocks_not_evicted():
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0].block_id == 0
|
||||
assert blocks[0].block_id == 1
|
||||
|
||||
# Allocate another block.
|
||||
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
|
||||
@ -433,7 +436,7 @@ def test_computed_blocks_not_evicted():
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0].block_id == 1
|
||||
assert blocks[0].block_id == 2
|
||||
|
||||
# Free the blocks.
|
||||
manager.free(req0)
|
||||
@ -444,13 +447,13 @@ def test_computed_blocks_not_evicted():
|
||||
req2 = make_request("2", list(range(num_tokens * 2)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(computed_blocks) == 1
|
||||
assert computed_blocks[0].block_id == 0
|
||||
assert computed_blocks[0].block_id == 1
|
||||
assert num_computed_tokens == block_size
|
||||
|
||||
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
|
||||
computed_blocks)
|
||||
assert len(blocks) == 1
|
||||
assert blocks[0].block_id == 1
|
||||
assert blocks[0].block_id == 2
|
||||
|
||||
|
||||
def test_basic_prefix_caching_disabled():
|
||||
@ -459,10 +462,8 @@ def test_basic_prefix_caching_disabled():
|
||||
"""
|
||||
block_size = 4
|
||||
manager = KVCacheManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=4,
|
||||
make_kv_cache_config(block_size, 5),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=False,
|
||||
num_preallocate_tokens=0,
|
||||
)
|
||||
@ -502,10 +503,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
|
||||
This tests that the preallocated blocks are correctly added.
|
||||
"""
|
||||
manager = KVCacheManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=10,
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=num_preallocate_tokens,
|
||||
)
|
||||
@ -586,10 +585,8 @@ def test_mm_prefix_caching():
|
||||
This tests that the multi-modal prefix caching is correct.
|
||||
"""
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
make_kv_cache_config(16, 11),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=16,
|
||||
)
|
||||
@ -629,7 +626,7 @@ def test_mm_prefix_caching():
|
||||
assert block_hashes[2].extra_keys == ("bbb", )
|
||||
|
||||
blocks = manager.allocate_slots(req0, 59, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
|
||||
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
@ -667,10 +664,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
"""
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=10,
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=0,
|
||||
)
|
||||
@ -723,10 +718,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
|
||||
def test_reset_prefix_cache():
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
make_kv_cache_config(16, 11),
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=0,
|
||||
)
|
||||
@ -736,7 +729,7 @@ def test_reset_prefix_cache():
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids)
|
||||
blocks = manager.allocate_slots(req0, 55)
|
||||
assert [b.block_id for b in blocks] == [0, 1, 2, 3]
|
||||
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
|
||||
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
@ -745,7 +738,7 @@ def test_reset_prefix_cache():
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert len(computed_blocks) == 3
|
||||
blocks = manager.allocate_slots(req1, 7, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [4]
|
||||
assert [b.block_id for b in blocks] == [5]
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
assert not manager.reset_prefix_cache()
|
||||
|
@ -2,12 +2,15 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
@ -66,12 +69,21 @@ def create_scheduler(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
)
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=10000, # A large number of blocks to hold all requests
|
||||
tensors={},
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(16, 1, 1, torch.float32, False))
|
||||
],
|
||||
)
|
||||
cache_config.num_gpu_blocks = 10000
|
||||
return Scheduler(
|
||||
scheduler_config,
|
||||
model_config,
|
||||
cache_config,
|
||||
lora_config=None,
|
||||
kv_cache_config=kv_cache_config,
|
||||
log_stats=True,
|
||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||
)
|
||||
|
tests/v1/core/test_specialized_manager.py (new file, 138 lines)
@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
|
||||
from vllm.v1.core.specialized_manager import SlidingWindowManager
|
||||
from vllm.v1.kv_cache_interface import SlidingWindowSpec
|
||||
|
||||
|
||||
def test_sliding_window_possible_cached_prefix():
|
||||
sliding_window_spec = SlidingWindowSpec(
|
||||
block_size=2,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
manager = SlidingWindowManager(sliding_window_spec, block_pool)
|
||||
|
||||
def run_one_case(block_is_cached, expect_length):
|
||||
block_hash_list = [
|
||||
BlockHashType(i, ()) for i in range(len(block_is_cached))
|
||||
]
|
||||
|
||||
block_pool.cached_block_hash_to_block.clear()
|
||||
|
||||
# Mock the block pool with the cached blocks
|
||||
for i, (block_hash,
|
||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||
if is_cached:
|
||||
block_pool.cached_block_hash_to_block[block_hash] = {
|
||||
i: block_pool.blocks[i + 10]
|
||||
}
|
||||
|
||||
computed_blocks = manager.find_longest_cache_hit(block_hash_list)
|
||||
assert len(computed_blocks) == expect_length
|
||||
|
||||
assert all(block == block_pool.null_block
|
||||
for block in computed_blocks[:expect_length - 2])
|
||||
for i in range(2):
|
||||
if i < expect_length:
|
||||
block_index = expect_length - i - 1
|
||||
assert computed_blocks[
|
||||
block_index].block_id == block_index + 10
|
||||
|
||||
run_one_case([False] * 10, 0)
|
||||
run_one_case([True], 1)
|
||||
run_one_case([True, False], 1)
|
||||
run_one_case([True, True], 2)
|
||||
run_one_case([True, True, False], 2)
|
||||
run_one_case([True, True, True], 3)
|
||||
run_one_case([True, True, True, False], 3)
|
||||
run_one_case([
|
||||
True, True, False, True, False, False, True, True, False, True, True,
|
||||
True
|
||||
], 12)
|
||||
run_one_case([
|
||||
True, True, False, True, False, False, True, True, False, False, False
|
||||
], 8)
|
||||
run_one_case([
|
||||
True, True, False, True, False, False, True, True, False, False, False,
|
||||
True
|
||||
], 8)
|
||||
|
||||
|
||||
def test_sliding_window_remove_skipped_blocks():
|
||||
sliding_window_spec = SlidingWindowSpec(
|
||||
block_size=2,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
||||
|
||||
manager = SlidingWindowManager(sliding_window_spec, block_pool)
|
||||
|
||||
null_block_id = block_pool.null_block.block_id
|
||||
|
||||
def id_to_block_table(ids):
|
||||
return [
|
||||
KVCacheBlock(id_)
|
||||
if id_ != null_block_id else block_pool.null_block for id_ in ids
|
||||
]
|
||||
|
||||
def assert_block_id(block_table, ids):
|
||||
for block, id_ in zip(block_table, ids):
|
||||
if id_ == null_block_id:
|
||||
assert block == block_pool.null_block
|
||||
else:
|
||||
assert block.block_id == id_
|
||||
|
||||
original_block_ids = [
|
||||
1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
|
||||
]
|
||||
block_table = id_to_block_table(original_block_ids)
|
||||
removed = manager.remove_skipped_blocks(block_table, 0)
|
||||
assert_block_id(removed, [])
|
||||
assert_block_id(block_table, original_block_ids)
|
||||
|
||||
# 4 tokens are computed. Only token 0 is out of the sliding window. As
|
||||
# block 1000 also contains token 1 that is in the sliding window, block 1000
|
||||
# cannot be removed.
|
||||
removed = manager.remove_skipped_blocks(block_table, 4)
|
||||
assert_block_id(removed, [])
|
||||
assert_block_id(block_table, original_block_ids)
|
||||
|
||||
# 5 tokens are computed. Token 0 & 1 are out of the sliding window.
|
||||
# Block 1000 can be removed.
|
||||
removed = manager.remove_skipped_blocks(block_table, 5)
|
||||
assert_block_id(removed, [original_block_ids[0]])
|
||||
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
|
||||
|
||||
# 6 tokens are computed. Token 0-2 are out of the sliding window.
|
||||
# Cannot remove new block as the block 1001 is still used by token 3.
|
||||
removed = manager.remove_skipped_blocks(block_table, 6)
|
||||
assert_block_id(removed, [])
|
||||
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
|
||||
|
||||
# 7 tokens are computed. Token 0-3 are out of the sliding window.
|
||||
# Block 1001 can be removed and block 1000 is already removed.
|
||||
removed = manager.remove_skipped_blocks(block_table, 7)
|
||||
assert_block_id(removed, [original_block_ids[1]])
|
||||
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
|
||||
|
||||
# 11 tokens are computed. Token 0-7 are out of the sliding window.
|
||||
# Block 1002 & 1003 can be removed now. Block 1003 represents a longer
|
||||
# sequence, and is expected to be evicted earlier than 1002, so the order
|
||||
# of removed blocks should be [1003, 1002].
|
||||
removed = manager.remove_skipped_blocks(block_table, 11)
|
||||
assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
|
||||
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
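The index arithmetic exercised by the test above can be checked standalone. The sketch below is illustrative only (the helper name is made up), using the same sliding_window=4, block_size=2 setup as the test:

def last_removable_block(num_computed_tokens: int, sliding_window: int,
                         block_size: int) -> int:
    # Tokens before `last_useful_token` are outside the sliding window.
    last_useful_token = num_computed_tokens - sliding_window + 1
    # Blocks strictly before this index hold only skipped tokens.
    return last_useful_token // block_size

assert last_removable_block(4, 4, 2) == 0   # token 0 skipped, but block 0 still holds token 1
assert last_removable_block(5, 4, 2) == 1   # tokens 0-1 skipped -> block 0 removable
assert last_removable_block(7, 4, 2) == 2   # tokens 0-3 skipped -> blocks 0-1 removable
assert last_removable_block(11, 4, 2) == 4  # tokens 0-7 skipped -> blocks 0-3 removable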
|
tests/v1/e2e/test_correctness_sliding_window.py (new file, 84 lines)
@ -0,0 +1,84 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
from ...core.block.e2e.test_correctness_sliding_window import (check_answers,
|
||||
prep_prompts)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig:
|
||||
sliding_window: int
|
||||
ln_range: tuple[int, int]
|
||||
|
||||
|
||||
model_config = {
|
||||
"bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)),
|
||||
"google/gemma-2-2b-it": TestConfig(4096, (400, 800)),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"bigcode/starcoder2-3b", # sliding window only
|
||||
"google/gemma-2-2b-it", # sliding window + full attention
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [5])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_sliding_window_retrival(monkeypatch, model, batch_size, seed):
|
||||
"""
|
||||
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
|
||||
asks for the value of one of them (which is outside the sliding window).
|
||||
If we tell it upfront which variable we are going to ask about, then
|
||||
it answers correctly (mostly).
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
test_config = model_config[model]
|
||||
|
||||
llm = LLM(model=model)
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
|
||||
|
||||
prompts, answer, indices = prep_prompts(batch_size,
|
||||
ln_range=test_config.ln_range)
|
||||
|
||||
check_length(prompts, llm, test_config.sliding_window)
|
||||
|
||||
# Fresh generation
|
||||
responses = llm.generate(prompts, sampling_params)
|
||||
check_answers(indices,
|
||||
answer,
|
||||
[response.outputs[0].text for response in responses],
|
||||
accept_rate=1.0)
|
||||
|
||||
# Re-generate with the same prompts to test prefix caching
|
||||
responses = llm.generate(prompts, sampling_params)
|
||||
check_answers(indices,
|
||||
answer,
|
||||
[response.outputs[0].text for response in responses],
|
||||
accept_rate=1.0)
|
||||
|
||||
|
||||
def check_length(prompts: list[str], llm: LLM, sliding_window: int):
|
||||
"""
|
||||
Check if the prompt length is valid, i.e., longer than the sliding window
|
||||
size and shorter than the model's max length.
|
||||
|
||||
Args:
|
||||
prompts: list of prompts
|
||||
llm: LLM object
|
||||
sliding_window: Sliding window size
|
||||
"""
|
||||
tokenizer = llm.get_tokenizer()
|
||||
max_model_len = llm.llm_engine.model_config.max_model_len
|
||||
assert any(
|
||||
len(tokenizer.encode(prompt)) > sliding_window
|
||||
for prompt in prompts), "Prompt is too short for test"
|
||||
assert all(
|
||||
len(tokenizer.encode(prompt)) <= max_model_len
|
||||
for prompt in prompts), "Prompt is too long for test"
|
@ -1116,8 +1116,7 @@ class CacheConfig:
|
||||
is_attention_free: Whether the model is attention-free.
|
||||
num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
|
||||
profiled num_gpu_blocks if specified. Does nothing if None.
|
||||
sliding_window: Sliding window size for the KV cache. Can not work with
|
||||
prefix caching enabled.
|
||||
sliding_window: Sliding window size for the KV cache.
|
||||
enable_prefix_caching: Whether to enable prefix caching.
|
||||
cpu_offload_gb: Size of the CPU offload buffer in GiB.
|
||||
"""
|
||||
|
@ -27,6 +27,7 @@ class BlockPool:
|
||||
"""
|
||||
|
||||
def __init__(self, num_gpu_blocks: int, enable_caching: bool):
|
||||
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
self.enable_caching = enable_caching
|
||||
# All kv-cache blocks.
|
||||
@ -50,6 +51,11 @@ class BlockPool:
|
||||
self.cached_block_hash_to_block: dict[BlockHashType, dict[
|
||||
int, KVCacheBlock]] = defaultdict(dict)
|
||||
|
||||
# To represent a placeholder block with block_id=0.
|
||||
# The ref_cnt of null_block is not maintained, so special care is needed to
|
||||
# avoid freeing it.
|
||||
self.null_block = self.free_block_queue.popleft()
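Because block 0 is now permanently reserved as the null placeholder, the first allocatable block is block 1; this is why the expected block ids in the tests above shift by one (e.g. [0, 1, 2, 3, 4] -> [1, 2, 3, 4, 5]). A minimal sketch of that invariant, assuming the constructor shown in this diff:

from vllm.v1.core.block_pool import BlockPool

pool = BlockPool(num_gpu_blocks=11, enable_caching=True)
assert pool.null_block.block_id == 0      # block 0 is the reserved placeholder
assert pool.get_num_free_blocks() == 10   # only the remaining 10 blocks are allocatable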
|
||||
|
||||
def get_cached_block(self,
|
||||
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
|
||||
"""Get a cached block by the block hash, or None if cache miss.
|
||||
@ -214,7 +220,7 @@ class BlockPool:
|
||||
for block in blocks:
|
||||
# ref_cnt=0 means this block is in the free list (i.e. eviction
|
||||
# candidate), so remove it.
|
||||
if block.ref_cnt == 0:
|
||||
if block.ref_cnt == 0 and block != self.null_block:
|
||||
self.free_block_queue.remove(block)
|
||||
block.incr_ref()
|
||||
|
||||
@ -228,7 +234,8 @@ class BlockPool:
|
||||
"""
|
||||
for block in ordered_blocks:
|
||||
block.decr_ref()
|
||||
if block.ref_cnt == 0:
|
||||
# null_block should not be added to the free list.
|
||||
if block.ref_cnt == 0 and block != self.null_block:
|
||||
self.free_block_queue.append(block)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
@ -241,10 +248,10 @@ class BlockPool:
|
||||
False otherwise.
|
||||
"""
|
||||
num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
|
||||
if num_used_blocks > 0:
|
||||
if num_used_blocks != 1: # The null block is always marked as used
|
||||
logger.warning(
|
||||
"Failed to reset prefix cache because some "
|
||||
"blocks (%d) are not freed yet", num_used_blocks)
|
||||
"blocks (%d) are not freed yet", num_used_blocks - 1)
|
||||
return False
|
||||
|
||||
# Remove all hashes so that no new blocks will hit.
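The accounting behind the adjusted `!= 1` check above is straightforward: the null block is never returned to the free queue, so an otherwise idle pool always reports exactly one used block. Illustrative numbers:

num_gpu_blocks, num_free_blocks = 11, 10   # idle pool: only the null block is held
num_used_blocks = num_gpu_blocks - num_free_blocks
assert num_used_blocks == 1                # the null block is always counted as used
can_reset = (num_used_blocks == 1)         # any larger value means real blocks are in use
assert can_reset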
|
||||
|
@ -9,6 +9,8 @@ from vllm.utils import cdiv, sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
||||
hash_request_tokens)
|
||||
from vllm.v1.core.specialized_manager import get_specialized_manager
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
@ -19,20 +21,22 @@ class KVCacheManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
max_model_len: int,
|
||||
sliding_window: Optional[int] = None,
|
||||
enable_caching: bool = True,
|
||||
caching_hash_algo: str = "builtin",
|
||||
num_preallocate_tokens: int = 64,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
self.block_size = block_size
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1, (
|
||||
"KVCacheManager does not support hybrid models with more than 1 "
|
||||
"kv cache group")
|
||||
kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.num_gpu_blocks = kv_cache_config.num_blocks
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
|
||||
self.sliding_window = sliding_window
|
||||
self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size)
|
||||
|
||||
self.enable_caching = enable_caching
|
||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||
# FIXME: make prefix cache stats conditional on log_stats
|
||||
@ -48,9 +52,15 @@ class KVCacheManager:
|
||||
# further allocation. When it uses up all the N empty blocks, it gets
|
||||
# N new empty blocks.
|
||||
self.num_preallocate_tokens = num_preallocate_tokens
|
||||
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
|
||||
self.num_preallocate_blocks = cdiv(num_preallocate_tokens,
|
||||
self.block_size)
|
||||
|
||||
self.block_pool = BlockPool(num_gpu_blocks, enable_caching)
|
||||
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching)
|
||||
|
||||
self.specialized_manager = get_specialized_manager(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
)
|
||||
|
||||
# Mapping from request ID to blocks to track the blocks allocated
|
||||
# for each request, so that we can free the blocks when the request
|
||||
@ -117,17 +127,25 @@ class KVCacheManager:
|
||||
|
||||
self.prefix_cache_stats.requests += 1
|
||||
if request.sampling_params.prompt_logprobs is None:
|
||||
# Check for cache hits
|
||||
computed_blocks = []
|
||||
for block_hash in block_hashes:
|
||||
# block_hashes is a chain of block hashes. If a block hash
|
||||
# is not in the cached_block_hash_to_id, the following
|
||||
# block hashes are not computed yet for sure.
|
||||
if cached_block := self.block_pool.get_cached_block(
|
||||
block_hash):
|
||||
computed_blocks.append(cached_block)
|
||||
else:
|
||||
break
|
||||
if len(block_hashes) * self.block_size == request.num_tokens:
|
||||
# When prompt length is divisible by the block size and all
|
||||
# blocks are cached, we need to recompute the last token. This
|
||||
# has to be achieved by re-computing an entire block because
|
||||
# allocate_slots() assumes num_computed_tokens is always a
|
||||
# multiple of the block size. To achieve this, remove the last
|
||||
# block hash from the block_hashes for find_longest_cache_hit
|
||||
# This limitation can potentially be removed in the future to
|
||||
# slightly improve the performance.
|
||||
last_block_hash = block_hashes.pop()
|
||||
else:
|
||||
last_block_hash = None
|
||||
|
||||
computed_blocks = (
|
||||
self.specialized_manager.find_longest_cache_hit(block_hashes))
|
||||
|
||||
if last_block_hash is not None:
|
||||
# Add back the last block hash if it was removed.
|
||||
block_hashes.append(last_block_hash)
|
||||
|
||||
self.prefix_cache_stats.queries += len(block_hashes)
|
||||
self.prefix_cache_stats.hits += len(computed_blocks)
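The full-prompt cache-hit branch above is easiest to see with concrete numbers; the values below are illustrative and the hashes are placeholders:

block_size = 16
num_prompt_tokens = 48                    # exactly 3 full blocks
block_hashes = ["h0", "h1", "h2"]         # one placeholder hash per full block
if len(block_hashes) * block_size == num_prompt_tokens:
    last_block_hash = block_hashes.pop()  # at most blocks 0-1 can now be cache hits
# The last block (tokens 32-47) is recomputed, so num_computed_tokens stays at 32,
# a multiple of the block size as allocate_slots() requires.
assert len(block_hashes) * block_size == 32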
|
||||
@ -176,13 +194,24 @@ class KVCacheManager:
|
||||
|
||||
new_computed_blocks = new_computed_blocks or []
|
||||
|
||||
req_blocks = self.req_to_blocks[request.request_id]
|
||||
|
||||
# Free the blocks that are skipped during the attention computation
|
||||
# (e.g., tokens outside the sliding window).
|
||||
# We can do this even if we cannot schedule this request due to
|
||||
# insufficient free blocks.
|
||||
# Should call this function before allocating new blocks to reduce
|
||||
# the number of evicted blocks.
|
||||
removed_blocks = self.specialized_manager.remove_skipped_blocks(
|
||||
req_blocks, request.num_computed_tokens)
|
||||
self.block_pool.free_blocks(removed_blocks)
|
||||
|
||||
# The total number of computed tokens is the previously computed tokens plus
|
||||
# the new prefix caching hits
|
||||
num_computed_tokens = (request.num_computed_tokens +
|
||||
len(new_computed_blocks) * self.block_size)
|
||||
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
|
||||
self.block_size)
|
||||
req_blocks = self.req_to_blocks[request.request_id]
|
||||
num_new_blocks = (num_required_blocks - len(req_blocks) -
|
||||
len(new_computed_blocks))
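A quick numeric check of the block-count bookkeeping above (illustrative values; cdiv is the ceiling-division helper from vllm.utils):

from vllm.utils import cdiv

block_size = 16
num_computed_tokens = 2 * block_size + 1 * block_size  # 2 blocks already owned + 1 new prefix hit
num_tokens = 20                                        # newly scheduled tokens
num_required_blocks = cdiv(num_computed_tokens + num_tokens, block_size)  # cdiv(68, 16) = 5
num_new_blocks = num_required_blocks - 2 - 1           # minus owned blocks and new cache hits
assert num_new_blocks == 2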
|
||||
|
||||
|
@ -9,8 +9,9 @@ from typing import Any, Callable, NamedTuple, Optional
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
|
||||
KVCacheSpec, KVCacheTensor)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
KVCacheTensor, SlidingWindowSpec)
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
@ -483,7 +484,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
needed_memory = 0
|
||||
for layer_spec in kv_cache_spec.values():
|
||||
needed_memory += layer_spec.bytes_for_tokens(max_model_len)
|
||||
needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
|
||||
|
||||
if needed_memory > available_memory:
|
||||
raise ValueError(
|
||||
@ -597,6 +598,33 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
return kv_cache_config
|
||||
|
||||
|
||||
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
"""
|
||||
Only models with one type of KV cache are supported so far. This function tries
|
||||
to convert the KV cache specs to one type if the model is a hybrid model
|
||||
with multiple types of KV cache. It will convert all SlidingWindowSpec to
|
||||
FullAttentionSpec if both types are present.
|
||||
|
||||
Args:
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
"""
|
||||
|
||||
has_full_attention = any(
|
||||
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
|
||||
has_sliding_window = any(
|
||||
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
|
||||
if has_full_attention and has_sliding_window:
|
||||
for layer_name, spec in kv_cache_spec.items():
|
||||
if isinstance(spec, SlidingWindowSpec):
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=spec.block_size,
|
||||
num_kv_heads=spec.num_kv_heads,
|
||||
head_size=spec.head_size,
|
||||
dtype=spec.dtype,
|
||||
use_mla=spec.use_mla,
|
||||
)
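A hedged usage sketch of the conversion above, building a small hybrid spec dict with the constructors shown elsewhere in this diff (the layer names are made up):

import torch

from vllm.v1.core.kv_cache_utils import unify_hybrid_kv_cache_specs
from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec

specs = {
    "layers.0.attn": FullAttentionSpec(block_size=16, num_kv_heads=8, head_size=128,
                                       dtype=torch.float16, use_mla=False),
    "layers.1.attn": SlidingWindowSpec(block_size=16, num_kv_heads=8, head_size=128,
                                       dtype=torch.float16, sliding_window=4096,
                                       use_mla=False),
}
unify_hybrid_kv_cache_specs(specs)
# Both layers are now FullAttentionSpec, so the uniform-type path below applies.
assert all(isinstance(spec, FullAttentionSpec) for spec in specs.values())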
|
||||
|
||||
|
||||
def get_kv_cache_config(vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
@ -613,6 +641,7 @@ def get_kv_cache_config(vllm_config: VllmConfig,
|
||||
The generated KVCacheConfigs
|
||||
"""
|
||||
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
|
||||
unify_hybrid_kv_cache_specs(kv_cache_spec)
|
||||
if is_kv_cache_type_uniform(kv_cache_spec):
|
||||
# KV cache of all layers are the same, which is true for
|
||||
# most models. Allocate the same amount of memory for
|
||||
|
@ -19,6 +19,7 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
from vllm.v1.core.sched.utils import check_stop
|
||||
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
|
||||
EngineCoreOutputs)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
@ -35,6 +36,7 @@ class Scheduler(SchedulerInterface):
|
||||
model_config: ModelConfig,
|
||||
cache_config: CacheConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
structured_output_manager: StructuredOutputManager,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
include_finished_set: bool = False,
|
||||
@ -43,6 +45,7 @@ class Scheduler(SchedulerInterface):
|
||||
self.scheduler_config = scheduler_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.log_stats = log_stats
|
||||
self.structured_output_manager = structured_output_manager
|
||||
|
||||
@ -58,15 +61,11 @@ class Scheduler(SchedulerInterface):
|
||||
self.scheduler_config.max_num_batched_tokens
|
||||
self.max_model_len = self.scheduler_config.max_model_len
|
||||
|
||||
num_gpu_blocks = cache_config.num_gpu_blocks
|
||||
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
|
||||
# Create the KV cache manager.
|
||||
self.kv_cache_manager = KVCacheManager(
|
||||
block_size=self.cache_config.block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
sliding_window=self.cache_config.sliding_window,
|
||||
enable_caching=self.cache_config.enable_prefix_caching,
|
||||
enable_caching=cache_config.enable_prefix_caching,
|
||||
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
|
||||
log_stats=self.log_stats)
|
||||
self.block_size = self.cache_config.block_size
|
||||
@ -300,17 +299,6 @@ class Scheduler(SchedulerInterface):
|
||||
# `request.num_prompt_tokens` to consider the resumed requests,
|
||||
# which have output tokens.
|
||||
num_new_tokens = request.num_tokens - num_computed_tokens
|
||||
if num_new_tokens == 0:
|
||||
# This happens when prompt length is divisible by the block
|
||||
# size and all blocks are cached. Now we force recomputation of
|
||||
# the last block. Note that we have to re-compute an entire
|
||||
# block because allocate_slots() assumes num_computed_tokens
|
||||
# is always a multiple of the block size. This limitation
|
||||
# can potentially be removed in the future to slightly
|
||||
# improve the performance.
|
||||
num_computed_tokens -= self.block_size
|
||||
num_new_tokens = self.block_size
|
||||
computed_blocks.pop()
|
||||
if (0 < self.scheduler_config.long_prefill_token_threshold <
|
||||
num_new_tokens):
|
||||
num_new_tokens = (
|
||||
|
vllm/v1/core/specialized_manager.py (new file, 161 lines)
@ -0,0 +1,161 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||
SlidingWindowSpec)
|
||||
|
||||
|
||||
class SpecializedManager(ABC):
|
||||
"""
|
||||
An abstract base class for specialized managers that handle the kv
|
||||
cache management logic of different attention layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
block_pool: BlockPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the SpecializedManager.
|
||||
Args:
|
||||
kv_cache_spec: The kv_cache_spec for this manager.
|
||||
block_pool: The block pool.
|
||||
"""
|
||||
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_pool = block_pool
|
||||
|
||||
@abstractmethod
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
||||
"""
|
||||
Get the longest cache hit prefix of the blocks. If no cache hit is
|
||||
found, return an empty list.
|
||||
|
||||
Args:
|
||||
block_hashes: The block hashes of the request.
|
||||
Returns:
|
||||
A list of cached blocks with skipped blocks replaced by null block.
|
||||
For example, sliding window manager should return a list like
|
||||
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
|
||||
sliding window 8.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
|
||||
num_computed_tokens: int) -> list[KVCacheBlock]:
|
||||
"""
|
||||
Remove the blocks that are no longer needed from `blocks`. The removed
|
||||
blocks should be replaced by null_block. Return the removed blocks in
|
||||
eviction order, where the first returned block should be evicted first.
|
||||
Don't free the removed blocks in this function.
|
||||
|
||||
Args:
|
||||
blocks: The list of blocks to be updated.
|
||||
num_computed_tokens: The number of tokens that have been computed.
|
||||
Returns:
|
||||
The removed blocks in eviction order.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FullAttentionManager(SpecializedManager):
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
||||
computed_blocks: list[KVCacheBlock] = []
|
||||
for block_hash in block_hashes:
|
||||
# block_hashes is a chain of block hashes. If a block hash is not
|
||||
# in the cached_block_hash_to_id, the following block hashes are
|
||||
# not computed yet for sure.
|
||||
if cached_block := self.block_pool.get_cached_block(block_hash):
|
||||
computed_blocks.append(cached_block)
|
||||
else:
|
||||
break
|
||||
return computed_blocks
|
||||
|
||||
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
|
||||
num_computed_tokens: int) -> list[KVCacheBlock]:
|
||||
# No need to remove blocks for full attention.
|
||||
return []
|
||||
|
||||
|
||||
class SlidingWindowManager(SpecializedManager):
|
||||
|
||||
def __init__(self, kv_cache_spec: SlidingWindowSpec,
|
||||
block_pool: BlockPool):
|
||||
super().__init__(kv_cache_spec, block_pool)
|
||||
self.sliding_window = kv_cache_spec.sliding_window
|
||||
# The number of contiguous blocks needed for a prefix cache hit.
|
||||
# -1 since the input token itself is also included in the window
|
||||
self.sliding_window_contiguous_blocks = cdiv(
|
||||
(kv_cache_spec.sliding_window - 1), self.block_size)
|
||||
self._null_block = block_pool.null_block
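A quick check of the `-1` in the formula above, using the sliding_window=4, block_size=2 values from the unit tests in this PR:

from vllm.utils import cdiv

sliding_window, block_size = 4, 2
# The current token is part of its own window, so only sliding_window - 1 = 3
# previous tokens must be cached, which spans cdiv(3, 2) = 2 contiguous blocks.
assert cdiv(sliding_window - 1, block_size) == 2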
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
||||
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
|
||||
# optimize the time complexity from O(len(block_hashes)) to
|
||||
# O(len(block_hashes) / sliding_window_contiguous_blocks +
|
||||
# sliding_window_contiguous_blocks),
|
||||
# which is good for low cache hit rate scenarios.
|
||||
computed_blocks = [self._null_block] * len(block_hashes)
|
||||
num_contiguous_blocks = 0
|
||||
|
||||
# Search from right to left and early stop when a match is found.
|
||||
for i in range(len(block_hashes) - 1, -1, -1):
|
||||
if cached_block := self.block_pool.get_cached_block(
|
||||
block_hashes[i]):
|
||||
computed_blocks[i] = cached_block
|
||||
num_contiguous_blocks += 1
|
||||
if (num_contiguous_blocks
|
||||
>= self.sliding_window_contiguous_blocks):
|
||||
# Trim the trailing blocks.
|
||||
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
|
||||
# when sliding_window_contiguous_blocks=2.
|
||||
del computed_blocks[i + num_contiguous_blocks:]
|
||||
return computed_blocks
|
||||
else:
|
||||
num_contiguous_blocks = 0
|
||||
# The first `num_contiguous_blocks` blocks are a cache hit even if
|
||||
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
||||
del computed_blocks[num_contiguous_blocks:]
|
||||
return computed_blocks
|
||||
|
||||
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
|
||||
num_computed_tokens: int) -> list[KVCacheBlock]:
|
||||
# Remove the blocks that are no longer in the sliding window and
|
||||
# skipped during the attention computation.
|
||||
last_useful_token = num_computed_tokens - self.sliding_window + 1
|
||||
last_useful_block = last_useful_token // self.block_size
|
||||
|
||||
removed_blocks: list[KVCacheBlock] = []
|
||||
for i in range(last_useful_block - 1, -1, -1):
|
||||
if blocks[i] == self._null_block:
|
||||
# If the block is already a null block, the blocks before it
|
||||
# should also have been set to null blocks by the previous calls
|
||||
# to this function.
|
||||
break
|
||||
removed_blocks.append(blocks[i])
|
||||
blocks[i] = self._null_block
|
||||
return removed_blocks
|
||||
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
SlidingWindowSpec: SlidingWindowManager,
|
||||
}
|
||||
|
||||
|
||||
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
|
||||
block_pool: BlockPool) -> SpecializedManager:
|
||||
manager_class = spec_manager_map[type(kv_cache_spec)]
|
||||
manager = manager_class(kv_cache_spec, block_pool)
|
||||
return manager
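A small usage sketch of the factory above; the values are illustrative, and the classes are the ones introduced in this diff:

import torch

from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.specialized_manager import (SlidingWindowManager,
                                              get_specialized_manager)
from vllm.v1.kv_cache_interface import SlidingWindowSpec

spec = SlidingWindowSpec(block_size=16, num_kv_heads=8, head_size=128,
                         dtype=torch.float16, sliding_window=4096, use_mla=False)
pool = BlockPool(num_gpu_blocks=128, enable_caching=True)
manager = get_specialized_manager(kv_cache_spec=spec, block_pool=pool)
assert isinstance(manager, SlidingWindowManager)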
|
@ -33,6 +33,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
@ -66,8 +67,9 @@ class EngineCore:
|
||||
self.model_executor = executor_class(vllm_config)
|
||||
|
||||
# Setup KV Caches and update CacheConfig after profiling.
|
||||
num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
|
||||
vllm_config)
|
||||
num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
|
||||
self._initialize_kv_caches(vllm_config)
|
||||
|
||||
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
@ -95,10 +97,11 @@ class EngineCore:
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
lora_config=vllm_config.lora_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
structured_output_manager=self.structured_output_manager,
|
||||
include_finished_set=vllm_config.parallel_config.data_parallel_size
|
||||
> 1,
|
||||
log_stats=self.log_stats,
|
||||
structured_output_manager=self.structured_output_manager,
|
||||
)
|
||||
|
||||
# Setup MM Input Mapper.
|
||||
@ -117,8 +120,8 @@ class EngineCore:
|
||||
self.batch_queue_size)
|
||||
self.batch_queue = queue.Queue(self.batch_queue_size)
|
||||
|
||||
def _initialize_kv_caches(self,
|
||||
vllm_config: VllmConfig) -> tuple[int, int]:
|
||||
def _initialize_kv_caches(
|
||||
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
|
||||
start = time.time()
|
||||
|
||||
# Get all kv cache needed by the model
|
||||
@ -143,13 +146,14 @@ class EngineCore:
|
||||
unify_kv_cache_configs(kv_cache_configs)
|
||||
|
||||
# All workers have the same kv_cache_config except layer names, so use
|
||||
# an arbitrary one to get the number of blocks.
|
||||
# an arbitrary one to initialize the scheduler.
|
||||
assert all([
|
||||
cfg.num_blocks == kv_cache_configs[0].num_blocks
|
||||
for cfg in kv_cache_configs
|
||||
])
|
||||
num_gpu_blocks = kv_cache_configs[0].num_blocks
|
||||
num_cpu_blocks = 0
|
||||
scheduler_kv_cache_config = kv_cache_configs[0]
|
||||
|
||||
# Initialize kv cache and warmup the execution
|
||||
self.model_executor.initialize_from_config(kv_cache_configs)
|
||||
@ -157,7 +161,7 @@ class EngineCore:
|
||||
elapsed = time.time() - start
|
||||
logger.info(("init engine (profile, create kv cache, "
|
||||
"warmup model) took %.2f seconds"), elapsed)
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
|
||||
|
||||
def add_request(self, request: EngineCoreRequest):
|
||||
"""Add request to the scheduler."""
|
||||
|
@ -4,6 +4,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv, get_dtype_size
|
||||
|
||||
@ -43,28 +44,23 @@ class KVCacheSpec:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def bytes_for_tokens(self, num_tokens: int) -> int:
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
"""
|
||||
The KV cache size for `num_tokens` tokens in bytes. Returns the real
|
||||
memory size after padding `num_tokens` to full blocks.
|
||||
The maximum possible memory usage of this KV cache in bytes.
|
||||
|
||||
Returns:
|
||||
The KV cache size
|
||||
The KV cache size in bytes
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class FullAttentionSpec(KVCacheSpec):
|
||||
class AttentionSpec(KVCacheSpec):
|
||||
num_kv_heads: int
|
||||
head_size: int
|
||||
dtype: torch.dtype
|
||||
use_mla: bool
|
||||
|
||||
@property
|
||||
def type_id(self) -> str:
|
||||
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
# For MLA we only store a single latent vector
|
||||
@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec):
|
||||
return coef * self.block_size * self.num_kv_heads * self.head_size \
|
||||
* get_dtype_size(self.dtype)
|
||||
|
||||
def bytes_for_tokens(self, num_tokens: int) -> int:
|
||||
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
|
||||
|
||||
@dataclass
|
||||
class FullAttentionSpec(AttentionSpec):
|
||||
|
||||
@property
|
||||
def type_id(self) -> str:
|
||||
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlidingWindowSpec(AttentionSpec):
|
||||
sliding_window: int
|
||||
|
||||
def __post_init__(self):
|
||||
assert not self.use_mla, "MLA is not supported for sliding window"
|
||||
|
||||
@property
|
||||
def type_id(self) -> str:
|
||||
return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
max_num_batched_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens)
|
||||
|
||||
# During chunked prefill, we allocate KV cache for the last
|
||||
# `self.sliding_window-1` computed tokens plus the newly scheduled
|
||||
# tokens. And we won't allocate KV cache for more than `max_model_len`
|
||||
# tokens.
|
||||
num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
|
||||
max_model_len)
|
||||
|
||||
# +1 here because the sliding window may not start from the beginning
|
||||
# of the block. For example, if the block size is 4 and num_tokens
|
||||
# is 4, we need two blocks [XXCD] [EF] to store the sliding
|
||||
# window [CDEF] (the last 4 tokens of a 6-token sequence).
|
||||
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
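To make the bound above concrete, here is the arithmetic with illustrative numbers (sliding_window=4096, block_size=16, max_num_batched_tokens=8192, max_model_len=32768):

from vllm.utils import cdiv

sliding_window, block_size = 4096, 16
max_num_batched_tokens, max_model_len = 8192, 32768
num_tokens = min(sliding_window - 1 + max_num_batched_tokens, max_model_len)  # 12287
num_blocks = cdiv(num_tokens, block_size) + 1  # 768 + 1; the +1 covers a window
                                               # that starts mid-block
assert num_blocks == 769
# The per-layer bound is then num_blocks * page_size_bytes.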
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -28,8 +28,9 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
check_use_alibi, is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
|
||||
KVCacheConfig, KVCacheSpec,
|
||||
SlidingWindowSpec)
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||
ModelRunnerOutput)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
@ -1572,7 +1573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# different GPUs, and `kv_cache_config.num_blocks` is set to
|
||||
# the min of all `num_blocks`. Verify it here.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
if isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
@ -1611,12 +1612,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# cross-attention
|
||||
assert isinstance(attn_module, Attention)
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=use_mla)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
|
@ -29,7 +29,7 @@ from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
|
||||
PallasMetadata)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
KVCacheSpec, SlidingWindowSpec)
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||
ModelRunnerOutput, SamplerOutput)
|
||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||
@ -353,17 +353,25 @@ class TPUModelRunner:
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
for layer_name, attn_module in forward_ctx.items():
|
||||
# TODO: Support other attention modules, e.g., sliding window,
|
||||
# cross-attention, MLA.
|
||||
assert isinstance(attn_module, Attention)
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
use_mla=False,
|
||||
)
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=False,
|
||||
)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
use_mla=False,
|
||||
)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
|