[V1] Implement sliding window attention in kv_cache_manager (#14097)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang 2025-04-01 15:33:17 +08:00 committed by GitHub
parent c7e63aa4d8
commit 3a5f0afcd2
15 changed files with 662 additions and 158 deletions

View File

@ -129,12 +129,16 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed,
check_answers(indices, answer, test_texts)
def prep_prompts(batch_size: int):
def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
"""
Generate prompts with a bunch of assignments,
then ask for the value of one of them.
The prompt is just under 10k tokens; the sliding window is 4k,
so the answer is outside the sliding window but should still be correct.
Args:
batch_size: number of prompts to generate
ln_range: range from which the number of assignment lines (and thus the prompt length) is sampled
"""
prompts: list[str] = []
answer: list[int] = []
@ -145,7 +149,7 @@ def prep_prompts(batch_size: int):
indices.append(idx)
prompt = "```python\n# We set a number of variables, " + \
f"x{idx} will be important later\n"
ln = random.randint(800, 1100)
ln = random.randint(*ln_range)
for k in range(30, ln):
v = random.randint(10, 99)
if k == idx:
@ -157,7 +161,10 @@ def prep_prompts(batch_size: int):
return prompts, answer, indices
def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
def check_answers(indices: list[int],
answer: list[int],
outputs: list[str],
accept_rate: float = 0.7):
answer2 = [int(text[0:2].strip()) for text in outputs]
print(list(zip(indices, zip(answer, answer2))))
numok = 0
@ -166,7 +173,7 @@ def check_answers(indices: list[int], answer: list[int], outputs: list[str]):
numok += 1
frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok > 0.7
assert frac_ok >= accept_rate
def check_window(prompts: list[str]):

View File

@ -4,6 +4,7 @@
from typing import Optional
import pytest
import torch
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
@ -12,6 +13,8 @@ from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_block_tokens)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
def make_request(request_id,
@ -39,13 +42,23 @@ def make_request(request_id,
)
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
return KVCacheConfig(
num_blocks=num_blocks,
tensors={},
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
def test_prefill(hash_algo):
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
caching_hash_algo=hash_algo,
num_preallocate_tokens=16,
@ -67,12 +80,12 @@ def test_prefill(hash_algo):
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
for block_id in (1, 2, 3):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
@ -80,7 +93,7 @@ def test_prefill(hash_algo):
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (3, 4):
for block_id in (4, 5):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
@ -90,11 +103,11 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
assert [b.block_id for b in blocks] == [6, 7]
for block in computed_blocks:
assert block.ref_cnt == 2
@ -107,14 +120,14 @@ def test_prefill(hash_algo):
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
# [unallocated (8, 9, 10)]
# [unique_req0 (5, 4)]
# [unique_req1 (7, 6)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
@ -122,11 +135,11 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [7, 8]
assert [b.block_id for b in blocks] == [8, 9]
# Although we only have 5 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
@ -148,7 +161,7 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
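(The uniform +1 shift in the expected block ids in these tests comes from the new null placeholder block: BlockPool now pops the first block off its free queue at construction, so real allocations start at block 1. A minimal sketch of that behavior, assuming the BlockPool constructor introduced later in this diff; the block count is illustrative.)

from vllm.v1.core.block_pool import BlockPool

pool = BlockPool(num_gpu_blocks=11, enable_caching=True)
# The first block in the free queue is reserved as the null placeholder,
# so only 10 of the 11 blocks remain allocatable -- which is why these tests
# now build the manager with make_kv_cache_config(16, 11).
assert pool.null_block.block_id == 0
assert pool.free_block_queue.num_free_blocks == 10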
@ -162,10 +175,8 @@ def test_prefill_plp():
3. Schedule plp request; no hit should occur; validate blocks
'''
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@ -186,13 +197,13 @@ def test_prefill_plp():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
req0_block_hashes = [b.block_hash for b in blocks]
# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
for block_id in (1, 2, 3):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
@ -200,7 +211,7 @@ def test_prefill_plp():
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (3, 4):
for block_id in (4, 5):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
@ -211,11 +222,11 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5, 6]
assert [b.block_id for b in blocks] == [6, 7]
for block in computed_blocks:
assert block.ref_cnt == 2
@ -228,14 +239,14 @@ def test_prefill_plp():
# All blocks should be available.
assert manager.block_pool.free_block_queue.num_free_blocks == 10
# The order should be
# [unallocated (7, 8, 9)]
# [unique_req0 (4, 3)]
# [unique_req1 (6, 5)]
# [common (2, 1, 0)]
# [unallocated (8, 9, 10)]
# [unique_req0 (5, 4)]
# [unique_req1 (7, 6)]
# [common (3, 2, 1)]
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]
] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]
# Request #2 is a prompt-logprobs request:
# NO cache hit in the common prefix; duplicates request #0 cached blocks
@ -251,7 +262,7 @@ def test_prefill_plp():
block_ids = [b.block_id for b in blocks]
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks] == req0_block_hashes
assert block_ids != [0, 1, 2, 3, 4]
assert block_ids != [1, 2, 3, 4, 5]
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
@ -263,10 +274,8 @@ def test_prefill_plp():
def test_decode():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@ -282,7 +291,7 @@ def test_decode():
assert not computed_blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@ -316,10 +325,8 @@ def test_decode():
def test_evict():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@ -350,15 +357,15 @@ def test_evict():
assert [
b.block_id
for b in manager.block_pool.free_block_queue.get_all_free_blocks()
] == [6, 5, 4, 3, 2, 1, 0, 9, 8, 7]
] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8]
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [0, 1]
assert [b.block_id for b in computed_blocks] == [1, 2]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert [b.block_id for b in blocks] == [7, 6]
assert manager.block_pool.free_block_queue.num_free_blocks == 6
@ -369,10 +376,8 @@ def test_hash_block_correct_reuse():
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=1,
make_kv_cache_config(16, 2),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@ -408,10 +413,8 @@ def test_computed_blocks_not_evicted():
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=2,
make_kv_cache_config(block_size, 3),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@ -424,7 +427,7 @@ def test_computed_blocks_not_evicted():
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 0
assert blocks[0].block_id == 1
# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
@ -433,7 +436,7 @@ def test_computed_blocks_not_evicted():
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
assert blocks[0].block_id == 2
# Free the blocks.
manager.free(req0)
@ -444,13 +447,13 @@ def test_computed_blocks_not_evicted():
req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1
assert computed_blocks[0].block_id == 0
assert computed_blocks[0].block_id == 1
assert num_computed_tokens == block_size
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
assert blocks[0].block_id == 2
def test_basic_prefix_caching_disabled():
@ -459,10 +462,8 @@ def test_basic_prefix_caching_disabled():
"""
block_size = 4
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=4,
make_kv_cache_config(block_size, 5),
max_model_len=8192,
sliding_window=None,
enable_caching=False,
num_preallocate_tokens=0,
)
@ -502,10 +503,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
This tests that the preallocated blocks are correctly added.
"""
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=10,
make_kv_cache_config(block_size, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=num_preallocate_tokens,
)
@ -586,10 +585,8 @@ def test_mm_prefix_caching():
This tests that the multi-modal prefix caching is correct.
"""
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=16,
)
@ -629,7 +626,7 @@ def test_mm_prefix_caching():
assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@ -667,10 +664,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=10,
make_kv_cache_config(block_size, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@ -723,10 +718,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
def test_reset_prefix_cache():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
make_kv_cache_config(16, 11),
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
@ -736,7 +729,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
assert [b.block_id for b in blocks] == [0, 1, 2, 3]
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
@ -745,7 +738,7 @@ def test_reset_prefix_cache():
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks)
assert [b.block_id for b in blocks] == [4]
assert [b.block_id for b in blocks] == [5]
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()

View File

@ -2,12 +2,15 @@
from typing import Optional
import pytest
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
@ -66,12 +69,21 @@ def create_scheduler(
model_config=model_config,
cache_config=cache_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=10000, # A large number of blocks to hold all requests
tensors={},
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(16, 1, 1, torch.float32, False))
],
)
cache_config.num_gpu_blocks = 10000
return Scheduler(
scheduler_config,
model_config,
cache_config,
lora_config=None,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)

View File

@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
from vllm.v1.core.specialized_manager import SlidingWindowManager
from vllm.v1.kv_cache_interface import SlidingWindowSpec
def test_sliding_window_possible_cached_prefix():
sliding_window_spec = SlidingWindowSpec(
block_size=2,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec, block_pool)
def run_one_case(block_is_cached, expect_length):
block_hash_list = [
BlockHashType(i, ()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
# Mock the block pool with the cached blocks
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[block_hash] = {
i: block_pool.blocks[i + 10]
}
computed_blocks = manager.find_longest_cache_hit(block_hash_list)
assert len(computed_blocks) == expect_length
assert all(block == block_pool.null_block
for block in computed_blocks[:expect_length - 2])
for i in range(2):
if i < expect_length:
block_index = expect_length - i - 1
assert computed_blocks[
block_index].block_id == block_index + 10
run_one_case([False] * 10, 0)
run_one_case([True], 1)
run_one_case([True, False], 1)
run_one_case([True, True], 2)
run_one_case([True, True, False], 2)
run_one_case([True, True, True], 3)
run_one_case([True, True, True, False], 3)
run_one_case([
True, True, False, True, False, False, True, True, False, True, True,
True
], 12)
run_one_case([
True, True, False, True, False, False, True, True, False, False, False
], 8)
run_one_case([
True, True, False, True, False, False, True, True, False, False, False,
True
], 8)
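# How to read the cases above: with block_size=2 and sliding_window=4,
# sliding_window_contiguous_blocks = cdiv(4 - 1, 2) = 2, so the manager only
# needs two contiguous cached blocks at the end of the hit; everything before
# them is returned as the null block. The 12-entry case ends with cached
# blocks at indices 10 and 11, giving a full-length hit of 12 with ten
# leading null blocks, while the 11-entry case ending in three misses falls
# back to the run at indices 6-7, giving a hit length of 8.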
def test_sliding_window_remove_skipped_blocks():
sliding_window_spec = SlidingWindowSpec(
block_size=2,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
manager = SlidingWindowManager(sliding_window_spec, block_pool)
null_block_id = block_pool.null_block.block_id
def id_to_block_table(ids):
return [
KVCacheBlock(id_)
if id_ != null_block_id else block_pool.null_block for id_ in ids
]
def assert_block_id(block_table, ids):
for block, id_ in zip(block_table, ids):
if id_ == null_block_id:
assert block == block_pool.null_block
else:
assert block.block_id == id_
original_block_ids = [
1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
]
block_table = id_to_block_table(original_block_ids)
removed = manager.remove_skipped_blocks(block_table, 0)
assert_block_id(removed, [])
assert_block_id(block_table, original_block_ids)
# 4 tokens are computed. Only token 0 is out of the sliding window. As
# block 1000 also contains token 1 that is in the sliding window, block 1000
# cannot be removed.
removed = manager.remove_skipped_blocks(block_table, 4)
assert_block_id(removed, [])
assert_block_id(block_table, original_block_ids)
# 5 tokens are computed. Tokens 0 & 1 are out of the sliding window.
# Block 1000 can be removed.
removed = manager.remove_skipped_blocks(block_table, 5)
assert_block_id(removed, [original_block_ids[0]])
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
# 6 tokens are computed. Tokens 0-2 are out of the sliding window.
# Cannot remove a new block because block 1001 is still needed by token 3.
removed = manager.remove_skipped_blocks(block_table, 6)
assert_block_id(removed, [])
assert_block_id(block_table, [null_block_id] + original_block_ids[1:])
# 7 tokens are computed. Tokens 0-3 are out of the sliding window.
# Block 1001 can be removed and block 1000 is already removed.
removed = manager.remove_skipped_blocks(block_table, 7)
assert_block_id(removed, [original_block_ids[1]])
assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
# 11 tokens are computed. Tokens 0-7 are out of the sliding window.
# Blocks 1002 & 1003 can be removed now. Block 1003 represents a longer
# sequence, and is expected to be evicted earlier than 1002, so the order
# of removed blocks should be [1003, 1002].
removed = manager.remove_skipped_blocks(block_table, 11)
assert_block_id(removed, [original_block_ids[3], original_block_ids[2]])
assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:])
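# The expectations above follow from the block arithmetic in
# SlidingWindowManager.remove_skipped_blocks: with sliding_window=4 and
# block_size=2, last_useful_token = num_computed_tokens - 4 + 1 and
# last_useful_block = last_useful_token // 2. For 5 computed tokens,
# last_useful_block = 2 // 2 = 1, so only blocks[0] (1000) is removable;
# for 11 computed tokens, last_useful_block = 8 // 2 = 4, so blocks 1003
# and 1002 are removed, in that eviction order.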

View File

@ -0,0 +1,84 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import pytest
from vllm import LLM, SamplingParams
from ...core.block.e2e.test_correctness_sliding_window import (check_answers,
prep_prompts)
@dataclass
class TestConfig:
sliding_window: int
ln_range: tuple[int, int]
model_config = {
"bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)),
"google/gemma-2-2b-it": TestConfig(4096, (400, 800)),
}
@pytest.mark.parametrize(
"model",
[
"bigcode/starcoder2-3b", # sliding window only
"google/gemma-2-2b-it", # sliding window + full attention
])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrival(monkeypatch, model, batch_size, seed):
"""
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
asks for the value of one of them (which is outside the sliding window).
If we tell the model upfront which variable we are going to be looking
for, it answers correctly (mostly).
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
test_config = model_config[model]
llm = LLM(model=model)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
prompts, answer, indices = prep_prompts(batch_size,
ln_range=test_config.ln_range)
check_length(prompts, llm, test_config.sliding_window)
# Fresh generation
responses = llm.generate(prompts, sampling_params)
check_answers(indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0)
# Re-generate with the same prompts to test prefix caching
responses = llm.generate(prompts, sampling_params)
check_answers(indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0)
def check_length(prompts: list[str], llm: LLM, sliding_window: int):
"""
Check if the prompt length is valid, i.e., longer than the sliding window
size and shorter than the model's max length.
Args:
prompts: list of prompts
llm: LLM object
sliding_window: Sliding window size
"""
tokenizer = llm.get_tokenizer()
max_model_len = llm.llm_engine.model_config.max_model_len
assert any(
len(tokenizer.encode(prompt)) > sliding_window
for prompt in prompts), "Prompt is too short for test"
assert all(
len(tokenizer.encode(prompt)) <= max_model_len
for prompt in prompts), "Prompt is too long for test"

View File

@ -1116,8 +1116,7 @@ class CacheConfig:
is_attention_free: Whether the model is attention-free.
num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
profiled num_gpu_blocks if specified. Does nothing if None.
sliding_window: Sliding window size for the KV cache. Can not work with
prefix caching enabled.
sliding_window: Sliding window size for the KV cache.
enable_prefix_caching: Whether to enable prefix caching.
cpu_offload_gb: Size of the CPU offload buffer in GiB.
"""

View File

@ -27,6 +27,7 @@ class BlockPool:
"""
def __init__(self, num_gpu_blocks: int, enable_caching: bool):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
# All kv-cache blocks.
@ -50,6 +51,11 @@ class BlockPool:
self.cached_block_hash_to_block: dict[BlockHashType, dict[
int, KVCacheBlock]] = defaultdict(dict)
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, so special care is
# needed to avoid freeing it.
self.null_block = self.free_block_queue.popleft()
def get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
@ -214,7 +220,7 @@ class BlockPool:
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0:
if block.ref_cnt == 0 and block != self.null_block:
self.free_block_queue.remove(block)
block.incr_ref()
@ -228,7 +234,8 @@ class BlockPool:
"""
for block in ordered_blocks:
block.decr_ref()
if block.ref_cnt == 0:
# null_block should not be added to the free list.
if block.ref_cnt == 0 and block != self.null_block:
self.free_block_queue.append(block)
def reset_prefix_cache(self) -> bool:
@ -241,10 +248,10 @@ class BlockPool:
False otherwise.
"""
num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
if num_used_blocks > 0:
if num_used_blocks != 1: # The null block is always marked as used
logger.warning(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet", num_used_blocks)
"blocks (%d) are not freed yet", num_used_blocks - 1)
return False
# Remove all hashes so that no new blocks will hit.

View File

@ -9,6 +9,8 @@ from vllm.utils import cdiv, sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_request_tokens)
from vllm.v1.core.specialized_manager import get_specialized_manager
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
@ -19,20 +21,22 @@ class KVCacheManager:
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
kv_cache_config: KVCacheConfig,
max_model_len: int,
sliding_window: Optional[int] = None,
enable_caching: bool = True,
caching_hash_algo: str = "builtin",
num_preallocate_tokens: int = 64,
log_stats: bool = False,
) -> None:
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
assert len(kv_cache_config.kv_cache_groups) == 1, (
"KVCacheManager does not support hybrid models with more than 1 "
"kv cache group")
kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = kv_cache_spec.block_size
self.num_gpu_blocks = kv_cache_config.num_blocks
self.max_model_len = max_model_len
self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
self.sliding_window = sliding_window
self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size)
self.enable_caching = enable_caching
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
# FIXME: make prefix cache stats conditional on log_stats
@ -48,9 +52,15 @@ class KVCacheManager:
# further allocation. When it uses up all the N empty blocks, it gets
# N new empty blocks.
self.num_preallocate_tokens = num_preallocate_tokens
self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
self.num_preallocate_blocks = cdiv(num_preallocate_tokens,
self.block_size)
self.block_pool = BlockPool(num_gpu_blocks, enable_caching)
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching)
self.specialized_manager = get_specialized_manager(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
)
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
@ -117,17 +127,25 @@ class KVCacheManager:
self.prefix_cache_stats.requests += 1
if request.sampling_params.prompt_logprobs is None:
# Check for cache hits
computed_blocks = []
for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash
# is not in the cached_block_hash_to_id, the following
# block hashes are not computed yet for sure.
if cached_block := self.block_pool.get_cached_block(
block_hash):
computed_blocks.append(cached_block)
else:
break
if len(block_hashes) * self.block_size == request.num_tokens:
# When the prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# has to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from block_hashes before calling
# find_longest_cache_hit(). This limitation can potentially be
# removed in the future to slightly improve the performance.
last_block_hash = block_hashes.pop()
else:
last_block_hash = None
computed_blocks = (
self.specialized_manager.find_longest_cache_hit(block_hashes))
if last_block_hash is not None:
# Add back the last block hash if it was removed.
block_hashes.append(last_block_hash)
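# Example of the branch above: for a fully cached 48-token prompt with
# block_size=16 there are 3 block hashes and 3 * 16 == 48, so the last
# hash is popped before the lookup. At most 2 blocks (32 tokens) are then
# reported as computed, and the final block is recomputed via
# allocate_slots().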
self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)
@ -176,13 +194,24 @@ class KVCacheManager:
new_computed_blocks = new_computed_blocks or []
req_blocks = self.req_to_blocks[request.request_id]
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
removed_blocks = self.specialized_manager.remove_skipped_blocks(
req_blocks, request.num_computed_tokens)
self.block_pool.free_blocks(removed_blocks)
# The total number of computed tokens is the previously computed
# tokens plus the new prefix cache hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
self.block_size)
req_blocks = self.req_to_blocks[request.request_id]
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))
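# Example of the arithmetic above: allocating for a 55-token prompt
# (48 cached + 7 new) with block_size=16 and no blocks allocated yet gives
# num_computed_tokens = 0 + 3 * 16 = 48,
# num_required_blocks = cdiv(48 + 7, 16) = 4 and
# num_new_blocks = 4 - 0 - 3 = 1, matching the single new block that
# test_reset_prefix_cache above expects for req1.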

View File

@ -9,8 +9,9 @@ from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, KVCacheTensor)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
@ -483,7 +484,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
max_model_len = vllm_config.model_config.max_model_len
needed_memory = 0
for layer_spec in kv_cache_spec.values():
needed_memory += layer_spec.bytes_for_tokens(max_model_len)
needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
if needed_memory > available_memory:
raise ValueError(
@ -597,6 +598,33 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return kv_cache_config
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
"""
Only models with a single type of KV cache are supported so far. If the model
is a hybrid model with multiple types of KV cache, this function tries to
convert the specs to a single type: all SlidingWindowSpec layers are converted
to FullAttentionSpec when both types are present.
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
has_full_attention = any(
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
if has_full_attention and has_sliding_window:
for layer_name, spec in kv_cache_spec.items():
if isinstance(spec, SlidingWindowSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
)
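As an illustration (not part of the diff; the layer names and shapes below are made up), a hybrid spec dict like Gemma-2's would be flattened as follows:

import torch

from vllm.v1.core.kv_cache_utils import unify_hybrid_kv_cache_specs
from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec

specs = {
    "layers.0.attn": SlidingWindowSpec(block_size=16, num_kv_heads=8,
                                       head_size=128, dtype=torch.bfloat16,
                                       use_mla=False, sliding_window=4096),
    "layers.1.attn": FullAttentionSpec(block_size=16, num_kv_heads=8,
                                       head_size=128, dtype=torch.bfloat16,
                                       use_mla=False),
}
unify_hybrid_kv_cache_specs(specs)
# Both layers now share the FullAttentionSpec type, so the single-group path
# in get_kv_cache_config still applies (at the cost of allocating
# full-attention-sized KV cache for the sliding-window layers).
assert all(isinstance(s, FullAttentionSpec) for s in specs.values())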
def get_kv_cache_config(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
@ -613,6 +641,7 @@ def get_kv_cache_config(vllm_config: VllmConfig,
The generated KVCacheConfigs
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
unify_hybrid_kv_cache_specs(kv_cache_spec)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for

View File

@ -19,6 +19,7 @@ from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@ -35,6 +36,7 @@ class Scheduler(SchedulerInterface):
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
@ -43,6 +45,7 @@ class Scheduler(SchedulerInterface):
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
self.kv_cache_config = kv_cache_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
@ -58,15 +61,11 @@ class Scheduler(SchedulerInterface):
self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len
num_gpu_blocks = cache_config.num_gpu_blocks
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching,
enable_caching=cache_config.enable_prefix_caching,
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
log_stats=self.log_stats)
self.block_size = self.cache_config.block_size
@ -300,17 +299,6 @@ class Scheduler(SchedulerInterface):
# `request.num_prompt_tokens` to consider the resumed requests,
# which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if num_new_tokens == 0:
# This happens when prompt length is divisible by the block
# size and all blocks are cached. Now we force to recompute
# the last block. Note that we have to re-compute an entire
# block because allocate_slots() assumes num_computed_tokens
# is always a multiple of the block size. This limitation
# can potentially be removed in the future to slightly
# improve the performance.
num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
computed_blocks.pop()
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (

View File

@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
class SpecializedManager(ABC):
"""
An abstract base class for specialized managers that handle the kv
cache management logic of different attention layers.
"""
def __init__(
self,
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
) -> None:
"""
Initializes the SpecializedManager.
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
"""
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
"""
Get the longest cache hit prefix of the blocks. If no cache hit is
found, return an empty list.
Args:
block_hashes: The block hashes of the request.
Returns:
A list of cached blocks with skipped blocks replaced by null block.
For example, sliding window manager should return a list like
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
sliding window 8.
"""
raise NotImplementedError
@abstractmethod
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
"""
Remove the blocks that are no longer needed from `blocks`. The removed
blocks should be replaced by null_block. Return the removed blocks in
eviction order, where the first returned block should be evicted first.
Don't free the removed blocks in this function.
Args:
blocks: The list of blocks to be updated.
num_computed_tokens: The number of tokens that have been computed.
Returns:
The removed blocks in eviction order.
"""
raise NotImplementedError
class FullAttentionManager(SpecializedManager):
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
computed_blocks: list[KVCacheBlock] = []
for block_hash in block_hashes:
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := self.block_pool.get_cached_block(block_hash):
computed_blocks.append(cached_block)
else:
break
return computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
# No need to remove blocks for full attention.
return []
class SlidingWindowManager(SpecializedManager):
def __init__(self, kv_cache_spec: SlidingWindowSpec,
block_pool: BlockPool):
super().__init__(kv_cache_spec, block_pool)
self.sliding_window = kv_cache_spec.sliding_window
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
self.sliding_window_contiguous_blocks = cdiv(
(kv_cache_spec.sliding_window - 1), self.block_size)
self._null_block = block_pool.null_block
def find_longest_cache_hit(
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(len(block_hashes)) to
# O(len(block_hashes) / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
computed_blocks = [self._null_block] * len(block_hashes)
num_contiguous_blocks = 0
# Search from right to left and early stop when a match is found.
for i in range(len(block_hashes) - 1, -1, -1):
if cached_block := self.block_pool.get_cached_block(
block_hashes[i]):
computed_blocks[i] = cached_block
num_contiguous_blocks += 1
if (num_contiguous_blocks
>= self.sliding_window_contiguous_blocks):
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
del computed_blocks[i + num_contiguous_blocks:]
return computed_blocks
else:
num_contiguous_blocks = 0
# The first `num_contiguous_blocks` blocks form a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del computed_blocks[num_contiguous_blocks:]
return computed_blocks
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
num_computed_tokens: int) -> list[KVCacheBlock]:
# Remove the blocks that are no longer in the sliding window and are
# therefore skipped during the attention computation.
last_useful_token = num_computed_tokens - self.sliding_window + 1
last_useful_block = last_useful_token // self.block_size
removed_blocks: list[KVCacheBlock] = []
for i in range(last_useful_block - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
return removed_blocks
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
FullAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
}
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
block_pool: BlockPool) -> SpecializedManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
manager = manager_class(kv_cache_spec, block_pool)
return manager
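A rough usage sketch of the new factory (sizes are illustrative; the names follow this diff):

import torch

from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.specialized_manager import (SlidingWindowManager,
                                              get_specialized_manager)
from vllm.v1.kv_cache_interface import SlidingWindowSpec

spec = SlidingWindowSpec(block_size=16, num_kv_heads=8, head_size=128,
                         dtype=torch.bfloat16, use_mla=False,
                         sliding_window=4096)
block_pool = BlockPool(num_gpu_blocks=1024, enable_caching=True)
manager = get_specialized_manager(kv_cache_spec=spec, block_pool=block_pool)
assert isinstance(manager, SlidingWindowManager)
# cdiv(4096 - 1, 16) = 256 trailing contiguous cached blocks are needed
# before a prefix hit can cover the whole window.
assert manager.sliding_window_contiguous_blocks == 256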

View File

@ -33,6 +33,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@ -66,8 +67,9 @@ class EngineCore:
self.model_executor = executor_class(vllm_config)
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
vllm_config)
num_gpu_blocks, num_cpu_blocks, kv_cache_config = \
self._initialize_kv_caches(vllm_config)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
@ -95,10 +97,11 @@ class EngineCore:
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
kv_cache_config=kv_cache_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
log_stats=self.log_stats,
structured_output_manager=self.structured_output_manager,
)
# Setup MM Input Mapper.
@ -117,8 +120,8 @@ class EngineCore:
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)
def _initialize_kv_caches(self,
vllm_config: VllmConfig) -> tuple[int, int]:
def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time()
# Get all kv cache needed by the model
@ -143,13 +146,14 @@ class EngineCore:
unify_kv_cache_configs(kv_cache_configs)
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to get the number of blocks.
# an arbitrary one to initialize the scheduler.
assert all([
cfg.num_blocks == kv_cache_configs[0].num_blocks
for cfg in kv_cache_configs
])
num_gpu_blocks = kv_cache_configs[0].num_blocks
num_cpu_blocks = 0
scheduler_kv_cache_config = kv_cache_configs[0]
# Initialize kv cache and warmup the execution
self.model_executor.initialize_from_config(kv_cache_configs)
@ -157,7 +161,7 @@ class EngineCore:
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
return num_gpu_blocks, num_cpu_blocks
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
def add_request(self, request: EngineCoreRequest):
"""Add request to the scheduler."""

View File

@ -4,6 +4,7 @@ from dataclasses import dataclass
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, get_dtype_size
@ -43,28 +44,23 @@ class KVCacheSpec:
"""
raise NotImplementedError
def bytes_for_tokens(self, num_tokens: int) -> int:
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
"""
The KV cache size for `num_tokens` tokens in bytes. Returns the real
memory size after padding `num_tokens` to full blocks.
The maximum possible memory usage of this KV cache in bytes.
Returns:
The KV cache size
The KV cache size in bytes
"""
raise NotImplementedError
@dataclass
class FullAttentionSpec(KVCacheSpec):
class AttentionSpec(KVCacheSpec):
num_kv_heads: int
head_size: int
dtype: torch.dtype
use_mla: bool
@property
def type_id(self) -> str:
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
@property
def page_size_bytes(self) -> int:
# For MLA we only store a single latent vector
@ -72,8 +68,47 @@ class FullAttentionSpec(KVCacheSpec):
return coef * self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype)
def bytes_for_tokens(self, num_tokens: int) -> int:
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
@dataclass
class FullAttentionSpec(AttentionSpec):
@property
def type_id(self) -> str:
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@dataclass
class SlidingWindowSpec(AttentionSpec):
sliding_window: int
def __post_init__(self):
assert not self.use_mla, "MLA is not supported for sliding window"
@property
def type_id(self) -> str:
return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
# During chunked prefill, we allocate KV cache for the last
# `self.sliding_window-1` computed tokens plus the newly scheduled
# tokens. And we won't allocate KV cache for more than `max_model_len`
# tokens.
num_tokens = min(self.sliding_window - 1 + max_num_batched_tokens,
max_model_len)
# +1 here because the sliding window may not start at the beginning
# of a block. For example, if the block size is 4 and num_tokens
# is 4, we need two blocks [XXCD] [EF] to store the sliding
# window [CDEF] of 4 tokens.
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
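# For example, with illustrative numbers: sliding_window=4096,
# block_size=16, max_num_batched_tokens=2048 and max_model_len=8192 give
# num_tokens = min(4095 + 2048, 8192) = 6143 and
# (cdiv(6143, 16) + 1) * page_size_bytes = 385 * page_size_bytes,
# far less than the 512 * page_size_bytes a FullAttentionSpec would need.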
@dataclass

View File

@ -28,8 +28,9 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
check_use_alibi, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec,
SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
@ -1572,7 +1573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# different GPUs, and `kv_cache_config.num_blocks` is set to
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
if isinstance(kv_cache_spec, FullAttentionSpec):
if isinstance(kv_cache_spec, AttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
@ -1611,12 +1612,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# cross-attention
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.

View File

@ -29,7 +29,7 @@ from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
KVCacheSpec, SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput, SamplerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
@ -353,17 +353,25 @@ class TPUModelRunner:
block_size = self.vllm_config.cache_config.block_size
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
# TODO: Support other attention modules, e.g., sliding window,
# cross-attention, MLA.
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
use_mla=False,
)
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
sliding_window=attn_module.sliding_window,
use_mla=False,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
use_mla=False,
)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.