[Core][Bugfix] Refactor block manager for better testability (#3492)

Cade Daniel 2024-03-27 23:59:28 -07:00 committed by GitHub
parent 8267b06c30
commit 14ccd94c89
30 changed files with 3285 additions and 77 deletions


@@ -0,0 +1,56 @@
import contextlib
import gc
import pytest
import ray
import torch
from vllm import LLM
from vllm.model_executor.parallel_utils.parallel_state import (
destroy_model_parallel)
from vllm.model_executor.utils import set_random_seed
def cleanup():
destroy_model_parallel()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
gc.collect()
torch.cuda.empty_cache()
ray.shutdown()
@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, seed):
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, seed)
@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed):
return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed)
def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
distinct_llm_kwargs, seed):
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**distinct_llm_kwargs,
}
def generator_inner():
llm = LLM(**kwargs)
set_random_seed(seed)
yield llm
del llm
cleanup()
for llm in generator_inner():
yield llm
del llm
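
The fixtures above expect the test module to supply the layered kwargs (common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs/test_llm_kwargs, seed) via pytest parametrization, exactly as the correctness test below does. A minimal sketch of that consumption pattern (values here are illustrative, not part of the commit):

# Illustrative sketch (not part of this commit): how a test module feeds the
# conftest fixtures above. Each generator yields a fully constructed LLM and
# tears it down (destroy_model_parallel, empty CUDA cache, ray.shutdown) after use.
import pytest

@pytest.mark.parametrize("common_llm_kwargs", [{"model": "facebook/opt-125m"}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"use_v2_block_manager": False}])
@pytest.mark.parametrize("seed", [1])
def test_example(baseline_llm_generator):
    for llm in baseline_llm_generator:
        outputs = llm.generate(["Hello, my name is"])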


@@ -0,0 +1,86 @@
from itertools import cycle
import pytest
from vllm import SamplingParams
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# skip cuda graph creation for fast test.
"enforce_eager": True,
# Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16,
"forced_num_gpu_blocks": 5 * (64 + 1),
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify block manager v2 produces same outputs as block manager v1, even
when there is preemption.
This constructs two LLMs, each with a limited number of GPU blocks. The limit
is chosen such that, as the sequences in the batch grow, sequences must be
preempted and removed from the cache.
If the output token ids are equivalent, then we have confidence that the KV
cache is not corrupted in the v2 block manager.
NOTE: We want a significant number of generated tokens so that any incorrect
KV mapping has time to build up error.
"""
output_len = 1024
temperature = 0.0
# We want to ensure equality even with preemption.
# We force the total block size to be 1 + cdiv(output_len, block_size)
# so that only one sequence can fit at a time (once the sequences grow).
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
print('Getting token ids from block manager v1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids from block manager v2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids
assert baseline_token_ids == test_token_ids
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
for llm in llm_generator:
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
del llm
return token_ids
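
To see why this configuration forces preemption: with block_size=16, each sequence needs roughly one block for its short prompt plus cdiv(1024, 16) = 64 blocks for the generated tokens, about 65 blocks in total, while forced_num_gpu_blocks = 5 * (64 + 1) = 325 fits only about 5 of the 10 sequences at once. A small sanity-check sketch of that arithmetic (not part of the diff):

# Illustrative sanity check of the test configuration above.
block_size = 16
output_len = 1024
batch_size = 10

blocks_per_seq = 1 + (output_len + block_size - 1) // block_size  # prompt block + 64 output blocks
forced_num_gpu_blocks = 5 * (64 + 1)                              # 325, as configured above

max_resident_seqs = forced_num_gpu_blocks // blocks_per_seq
assert blocks_per_seq == 65
assert max_resident_seqs == 5
assert max_resident_seqs < batch_size  # so some sequences must be preempted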


@@ -0,0 +1,50 @@
import pytest
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.core.interfaces import AllocStatus
from ..utils import create_seq_group
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
num_gpu_blocks: int, watermark: float):
block_manager = BlockSpaceManagerV2(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
watermark=watermark,
)
num_watermark_blocks = int(watermark * num_gpu_blocks)
num_output_blocks_per_seq = 1
# NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
# the current implementation assumes all seqs are new prompts / don't have
# different output lens.
num_output_blocks = num_output_blocks_per_seq
for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
seq_group = create_seq_group(
seq_prompt_lens=block_size * num_prompt_blocks,
seq_output_lens=[
block_size * num_output_blocks_per_seq
for _ in range(num_seqs_per_group)
],
)
assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks
can_allocate_result = block_manager.can_allocate(seq_group)
num_required_blocks = num_prompt_blocks + num_output_blocks
if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
assert can_allocate_result == AllocStatus.NEVER
elif num_gpu_blocks >= num_required_blocks:
assert can_allocate_result == AllocStatus.OK
else:
assert can_allocate_result == AllocStatus.LATER
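
The watermark expectations asserted above can be restated as a small decision rule. This is only a restatement of the test's assertions; the real logic lives in BlockSpaceManagerV2.can_allocate and may differ in detail:

# Illustrative restatement of the assertions in test_can_allocate_seq_group.
from vllm.core.interfaces import AllocStatus

def expected_can_allocate(num_gpu_blocks: int, num_required_blocks: int,
                          watermark: float) -> AllocStatus:
    num_watermark_blocks = int(watermark * num_gpu_blocks)
    if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
        # Even after allocation we would dip below the watermark: never schedulable.
        return AllocStatus.NEVER
    if num_gpu_blocks >= num_required_blocks:
        return AllocStatus.OK
    return AllocStatus.LATER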


@@ -0,0 +1,500 @@
import pytest
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.utils import Device, cdiv, chunk_list
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
def test_allocate_naive(block_size: int, sequence_len: int):
"""Test the allocation of blocks using the naive allocator.
This test creates a CpuGpuBlockAllocator with the specified block size and
number of blocks. It then allocates multiple BlockTables with varying
sequence lengths and verifies that the number of free blocks decreases as
expected after each allocation.
"""
assert block_size > 1
num_gpu_blocks = 1024
allocator = CpuGpuBlockAllocator.create(
allocator_type="naive",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
block_size=block_size,
)
token_ids = list(range(sequence_len))
num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
block_tables = []
for i in range(5):
assert allocator.get_num_free_blocks(
device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc
block_tables.append(
BlockTable(
block_size=block_size,
block_allocator=allocator,
))
block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU)
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
def test_allocate_prefix_caching(block_size: int, sequence_len: int):
"""Test the allocation of blocks using the prefix caching allocator.
This test creates a CpuGpuBlockAllocator with the specified block size and
number of blocks, using the prefix caching allocator. It then allocates
multiple BlockTables with varying sequence lengths and verifies that the
number of free blocks decreases as expected after each allocation.
The test expects all sequences to share allocations, except for their last
block, which may be mutable. It calculates the expected number of immutable
and mutable blocks per allocation based on the sequence length and block
size.
"""
assert block_size > 1
num_gpu_blocks = 1024
allocator = CpuGpuBlockAllocator.create(
allocator_type="prefix_caching",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
block_size=block_size,
)
token_ids = list(range(sequence_len))
chunked_tokens = list(chunk_list(token_ids, block_size))
num_mutable_blocks_per_alloc = 0 if len(
chunked_tokens[-1]) == block_size else 1
num_immutable_blocks_per_alloc = len(
chunked_tokens) - num_mutable_blocks_per_alloc
block_tables = []
for alloc_i in range(1, 6):
block_tables.append(
BlockTable(
block_size=block_size,
block_allocator=allocator,
))
block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU)
# Expect all sequences to share allocations, except for their last block
# (which may be mutable).
assert allocator.get_num_free_blocks(
device=Device.GPU) == num_gpu_blocks - (
num_immutable_blocks_per_alloc + num_mutable_blocks_per_alloc *
(alloc_i))
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
@pytest.mark.parametrize("device", ["cpu", "gpu"])
def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str,
device: str):
"""Test the allocation and freeing of blocks using different allocators and
devices.
This test creates a CpuGpuBlockAllocator with the specified block size,
number of blocks, allocator type, and device. It then allocates a BlockTable
multiple times with the same sequence and verifies that the number of free
blocks remains consistent after each allocation and freeing.
"""
device = Device[device.upper()]
num_device_blocks = 1024
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_device_blocks,
num_cpu_blocks=num_device_blocks,
block_size=block_size,
)
token_ids = list(range(sequence_len))
num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)
for i in range(5):
block_table.allocate(token_ids=token_ids, device=device)
assert allocator.get_num_free_blocks(
device) == num_device_blocks - num_blocks_per_alloc
assert all(block_id is not None
for block_id in block_table.physical_block_ids)
block_table.free()
assert allocator.get_num_free_blocks(device) == num_device_blocks
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_append_token_ids_allocation(block_size: int, sequence_len: int,
append_len: int, allocator_type: str):
"""Test the allocation behavior when appending token IDs to a BlockTable.
This test creates a CpuGpuBlockAllocator with the specified block size,
number of blocks, and allocator type. It then allocates a BlockTable with an
initial sequence and appends additional token IDs to it. The test verifies
that the number of allocated blocks before and after appending matches the
expected values.
"""
num_gpu_blocks = 1024
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
block_size=block_size,
)
token_ids = list(range(sequence_len))
token_ids_to_append = list(range(append_len))
block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)
num_expected_blocks_before_append = len(
list(chunk_list(token_ids, block_size)))
num_expected_appended_blocks = len(
list(chunk_list(token_ids + token_ids_to_append,
block_size))) - num_expected_blocks_before_append
block_table.allocate(token_ids=token_ids, device=Device.GPU)
assert len(
block_table.physical_block_ids) == num_expected_blocks_before_append
block_table.append_token_ids(token_ids_to_append)
assert len(
block_table.physical_block_ids
) == num_expected_blocks_before_append + num_expected_appended_blocks
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("num_empty_slots", [1, 16, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int,
num_empty_slots: int,
allocator_type: str):
"""Test the allocation behavior when ensuring a certain number of empty
slots in a BlockTable.
This test creates a CpuGpuBlockAllocator with the specified block size,
number of blocks, and allocator type. It then allocates a BlockTable with an
initial sequence and ensures a certain number of empty slots. The test
verifies that the number of allocated blocks before and after ensuring empty
slots matches the expected values. It also checks that filling up the empty
slots does not consume additional blocks.
"""
num_gpu_blocks = 1024
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
block_size=block_size,
)
token_ids = list(range(sequence_len))
block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)
num_expected_blocks_before_append = len(
list(chunk_list(token_ids, block_size)))
num_expected_appended_blocks = len(
list(chunk_list(token_ids + [-1] * num_empty_slots,
block_size))) - num_expected_blocks_before_append
block_table.allocate(token_ids=token_ids, device=Device.GPU)
# Assert that the empty slots consume the expected number of additional
# blocks.
assert len(
block_table.physical_block_ids) == num_expected_blocks_before_append
block_table.ensure_num_empty_slots(num_empty_slots)
assert len(
block_table.physical_block_ids
) == num_expected_blocks_before_append + num_expected_appended_blocks
# Now, ensure no additional blocks consumed as we fill up the empty slots.
num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU)
block_table.append_token_ids(token_ids=list(range(num_empty_slots)))
assert num_free_blocks == allocator.get_num_free_blocks(device=Device.GPU)
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 9])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("append_size", [1, 4, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_append_token_ids_correct_content(block_size: int, sequence_len: int,
append_len: int, allocator_type: str,
append_size: int):
"""Verify token ids are correctly appended. Appends various amounts of
token ids in various append sizes, and verifies the final sequence is
correct.
"""
num_gpu_blocks = 1024
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=1024,
block_size=block_size,
)
token_ids = list(range(sequence_len))
token_ids_to_append = list(range(append_len))
block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)
block_table.allocate(token_ids=token_ids, device=Device.GPU)
appended_so_far = []
for append in chunk_list(token_ids_to_append, append_size):
block_table.append_token_ids(append)
appended_so_far.extend(append)
assert block_table._get_all_token_ids() == token_ids + appended_so_far
assert block_table._get_all_token_ids() == token_ids + token_ids_to_append
@pytest.mark.parametrize("seq_len", [1, 9, 129])
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_fork(seq_len: int, block_size: int, allocator_type: str):
"""Create a sequence using the specified allocator.
1. Assert that after forking the sequence, the free block count is the
same.
2. Assert that the forked sequence has the same physical mappings.
3. Then free the original sequence; verify that the free block count is
the same.
4. Finally, free the forked sequence and verify that all blocks are freed
(the free block count returns to the total number of GPU blocks).
"""
num_gpu_blocks = 1024
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
block_size=block_size,
)
token_ids = list(range(seq_len))
block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)
block_table.allocate(token_ids)
num_free_blocks_before_fork = allocator.get_num_free_blocks(
device=Device.GPU)
forked_block_table = block_table.fork()
# Expect physical_block_ids and token_ids to match.
assert (block_table.physical_block_ids ==
forked_block_table.physical_block_ids)
assert block_table._get_all_token_ids(
) == forked_block_table._get_all_token_ids()
# Do not expect any additional allocations.
assert allocator.get_num_free_blocks(
device=Device.GPU) == num_free_blocks_before_fork
# Free the original blocks. Assert num free blocks does not change, since
# refcount is nonzero.
block_table.free()
assert allocator.get_num_free_blocks(
device=Device.GPU) == num_free_blocks_before_fork
# Expect the forked block table to be unaffected by the free.
assert all(block_id is not None
for block_id in forked_block_table.physical_block_ids)
# Free the forked blocks. Assert num free blocks does change, since
# refcount is now zero.
forked_block_table.free()
assert allocator.get_num_free_blocks(device=Device.GPU) == num_gpu_blocks
@pytest.mark.parametrize("block_size", [8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("appender", ["forked", "original"])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_cow(block_size: int, sequence_len: int, append_len: int,
allocator_type: str, appender: str):
"""Fork a sequence; append to the forked sequence; verify there's a CoW.
"""
num_gpu_blocks = 1024
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
block_size=block_size,
)
token_ids = list(range(sequence_len))
token_ids_to_append = list(range(append_len))
original_block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)
num_expected_non_cow_blocks = cdiv(sequence_len, block_size)
num_expected_cow_blocks = cdiv(sequence_len + append_len,
block_size) - (sequence_len // block_size)
original_block_table.allocate(token_ids=token_ids, device=Device.GPU)
original_block_ids = original_block_table.physical_block_ids
forked_block_table = original_block_table.fork()
# Expect no additional allocation (copy on _write_).
assert allocator.get_num_free_blocks(
Device.GPU) == (num_gpu_blocks - num_expected_non_cow_blocks)
if appender == "forked":
appender_block_table = forked_block_table
static_block_table = original_block_table
elif appender == "original":
appender_block_table = original_block_table
static_block_table = forked_block_table
else:
raise ValueError(f"unknown test config {appender=}")
# Write tokens.
appender_block_table.append_token_ids(token_ids_to_append)
# Expect the non-appending block table to have no change.
assert static_block_table.physical_block_ids == original_block_ids
assert appender_block_table.physical_block_ids != original_block_ids
# Expect the blocks changed during append to have a CoW.
assert allocator.get_num_free_blocks(
Device.GPU) == num_gpu_blocks - (num_expected_non_cow_blocks +
num_expected_cow_blocks)
cows = allocator.clear_copy_on_writes()
if sequence_len % block_size > 0:
# If the last block in the sequence is not full, then when appending we
# expect a CoW.
assert cows
cow_block_id = sequence_len // block_size
expected_src = static_block_table.physical_block_ids[cow_block_id]
expected_dst = appender_block_table.physical_block_ids[cow_block_id]
assert expected_src in cows
assert expected_dst in cows[expected_src]
else:
# Otherwise, there should be no copy-on-write.
assert not cows
static_block_table.free()
appender_block_table.free()
# After free, expect all blocks to be freed.
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
@pytest.mark.parametrize("block_size", [8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("lookahead_slots", [1, 16, 129])
@pytest.mark.parametrize("appender", ["forked", "original"])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_cow_lookahead_simple(block_size: int, sequence_len: int,
append_len: int, lookahead_slots: int,
allocator_type: str, appender: str):
"""Similar to test_cow, except with lookahead allocation. The assertions are
less rigorous due to the complexity of the property under test.
"""
num_gpu_blocks = 1024
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=0,
block_size=block_size,
)
token_ids = list(range(sequence_len))
token_ids_to_append = list(range(append_len))
original_block_table = BlockTable(
block_size=block_size,
block_allocator=allocator,
)
original_block_table.allocate(token_ids=token_ids, device=Device.GPU)
# Allocate lookahead slots.
original_block_table.ensure_num_empty_slots(lookahead_slots)
original_block_ids = original_block_table.physical_block_ids
forked_block_table = original_block_table.fork()
if appender == "forked":
appender_block_table = forked_block_table
static_block_table = original_block_table
elif appender == "original":
appender_block_table = original_block_table
static_block_table = forked_block_table
else:
raise ValueError(f"unknown test config {appender=}")
# Write tokens.
appender_block_table.append_token_ids(token_ids_to_append)
# Expect the non-appending block table to have no change.
assert static_block_table.physical_block_ids == original_block_ids
assert appender_block_table.physical_block_ids != original_block_ids
cows = allocator.clear_copy_on_writes()
# Always expect copy-on-write
assert cows
if sequence_len % block_size > 0:
# If the last block in the sequence is not full, then when appending we
# expect a CoW.
assert cows
cow_block_id = sequence_len // block_size
expected_src = static_block_table.physical_block_ids[cow_block_id]
expected_dst = appender_block_table.physical_block_ids[cow_block_id]
assert expected_src in cows
assert expected_dst in cows[expected_src]
static_block_table.free()
appender_block_table.free()
# After free, expect all blocks to be freed.
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
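
These tests lean on two small helpers imported from vllm.utils, cdiv and chunk_list. Their behaviour as exercised above is sketched with stand-ins below (an illustration of the assumed semantics, not the actual vllm.utils implementation):

# Illustrative stand-ins for the vllm.utils helpers used by the tests above.
from typing import List

def cdiv(a: int, b: int) -> int:
    # Ceiling division: number of blocks needed for `a` slots of size `b`.
    return -(a // -b)

def chunk_list(lst: List[int], chunk_size: int) -> List[List[int]]:
    # Split into consecutive chunks of `chunk_size`; the last chunk may be short.
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

assert cdiv(129, 16) == 9
assert [len(c) for c in chunk_list(list(range(129)), 16)] == [16] * 8 + [1]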


@@ -0,0 +1,42 @@
import random
import pytest
from vllm.core.block.common import RefCounter
@pytest.mark.parametrize("seed", list(range(20)))
@pytest.mark.parametrize("num_incrs", [1, 100])
@pytest.mark.parametrize("num_blocks", [1024])
def test_incr(seed: int, num_incrs: int, num_blocks: int):
random.seed(seed)
all_block_indices = list(range(num_blocks))
counter = RefCounter(all_block_indices=all_block_indices)
block_id = random.randint(0, num_blocks - 1)
for i in range(num_incrs):
value = counter.incr(block_id)
assert value == i + 1
@pytest.mark.parametrize("seed", list(range(20)))
@pytest.mark.parametrize("num_incrs", [1, 100])
@pytest.mark.parametrize("num_blocks", [1024])
def test_incr_decr(seed: int, num_incrs: int, num_blocks: int):
random.seed(seed)
all_block_indices = list(range(num_blocks))
counter = RefCounter(all_block_indices=all_block_indices)
block_id = random.randint(0, num_blocks - 1)
for i in range(num_incrs):
value = counter.incr(block_id)
assert value == i + 1
for i in range(num_incrs):
value = counter.decr(block_id)
assert value == num_incrs - (i + 1)
with pytest.raises(AssertionError):
counter.decr(block_id)


@@ -0,0 +1,93 @@
import pytest
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.utils import Device, chunk_list
@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
@pytest.mark.parametrize("num_gpu_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int,
block_size: int, allocator_type: str):
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=block_size,
)
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
cpu_blocks = [
allocator.allocate_mutable(prev_block=None, device=Device.CPU)
for _ in range(num_cpu_blocks)
]
assert allocator.get_num_free_blocks(Device.CPU) == 0
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
gpu_blocks = [
allocator.allocate_mutable(prev_block=None, device=Device.GPU)
for _ in range(num_gpu_blocks)
]
assert allocator.get_num_free_blocks(Device.CPU) == 0
assert allocator.get_num_free_blocks(Device.GPU) == 0
_ = [allocator.free(block) for block in cpu_blocks]
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == 0
_ = [allocator.free(block) for block in gpu_blocks]
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
@pytest.mark.parametrize("num_gpu_blocks", [1024])
@pytest.mark.parametrize("block_size", [2])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int,
block_size: int, allocator_type: str):
allocator = CpuGpuBlockAllocator.create(
allocator_type=allocator_type,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=block_size,
)
unique_token_ids = list(
range((num_cpu_blocks + num_gpu_blocks) * block_size))
gpu_token_ids = chunk_list(unique_token_ids[:num_gpu_blocks * block_size],
block_size)
cpu_token_ids = chunk_list(unique_token_ids[num_gpu_blocks * block_size:],
block_size)
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
cpu_blocks = [
allocator.allocate_immutable(prev_block=None,
token_ids=token_ids,
device=Device.CPU)
for token_ids in cpu_token_ids
]
assert allocator.get_num_free_blocks(Device.CPU) == 0
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
gpu_blocks = [
allocator.allocate_immutable(prev_block=None,
token_ids=token_ids,
device=Device.GPU)
for token_ids in gpu_token_ids
]
assert allocator.get_num_free_blocks(Device.CPU) == 0
assert allocator.get_num_free_blocks(Device.GPU) == 0
_ = [allocator.free(block) for block in cpu_blocks]
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == 0
_ = [allocator.free(block) for block in gpu_blocks]
assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks
assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks


@@ -0,0 +1,102 @@
from typing import List, Optional
import pytest
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
class TestNaiveBlockAllocator:
@staticmethod
def create_allocate_lambda(allocate_type: str,
allocator: NaiveBlockAllocator,
prev_block: Optional[Block],
token_ids: List[int]):
if allocate_type == "immutable":
allocate_block = lambda: allocator.allocate_immutable(
prev_block=prev_block, token_ids=token_ids)
elif allocate_type == "mutable":
allocate_block = lambda: allocator.allocate_mutable(prev_block=
prev_block)
else:
raise ValueError()
return allocate_block
@staticmethod
@pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
@pytest.mark.parametrize("num_blocks", [1, 1024])
@pytest.mark.parametrize("block_size", [1, 16])
def test_allocate_ooms(allocate_type: str, num_blocks: int,
block_size: int):
allocator = NaiveBlockAllocator(create_block=NaiveBlock,
num_blocks=num_blocks,
block_size=block_size)
allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
allocate_type,
allocator,
prev_block=None,
token_ids=list(range(block_size)))
[allocate_block() for _ in range(num_blocks)]
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocate_block()
@staticmethod
@pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
@pytest.mark.parametrize("num_blocks", [1, 1024])
@pytest.mark.parametrize("block_size", [1, 16])
def test_free_prevents_oom(allocate_type: str, num_blocks: int,
block_size: int):
allocator = NaiveBlockAllocator(create_block=NaiveBlock,
num_blocks=num_blocks,
block_size=block_size)
allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
allocate_type,
allocator,
prev_block=None,
token_ids=list(range(block_size)))
blocks = [allocate_block() for _ in range(num_blocks)]
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocate_block()
block_to_free = blocks.pop()
for _ in range(100):
block_id = block_to_free.block_id
allocator.free(block_to_free)
assert block_to_free.block_id is None
new_block = allocate_block()
assert new_block.block_id == block_id
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocate_block()
block_to_free = new_block
@staticmethod
@pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
def test_get_num_free_blocks(allocate_type: str, num_blocks: int,
block_size: int):
allocator = NaiveBlockAllocator(create_block=NaiveBlock,
num_blocks=num_blocks,
block_size=block_size)
allocate_block = TestNaiveBlockAllocator.create_allocate_lambda(
allocate_type,
allocator,
prev_block=None,
token_ids=list(range(block_size)))
assert allocator.get_num_free_blocks() == num_blocks
blocks = [allocate_block() for _ in range(num_blocks)]
for i, block in enumerate(blocks):
assert allocator.get_num_free_blocks() == i
allocator.free(block)


@@ -0,0 +1,384 @@
import math
import random
from typing import List, Optional
from unittest.mock import MagicMock
import pytest
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.prefix_caching_block import (PrefixCachingBlock,
PrefixCachingBlockAllocator)
class TestPrefixCachingBlock:
@staticmethod
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("block_size", [1, 16])
@pytest.mark.parametrize("is_curr_block_full", [True, False])
def test_first_block_has_correct_content_hash(seed: int, block_size: int,
is_curr_block_full: bool):
"""Verify a block which is first in the sequence has the correct hash.
"""
random.seed(seed)
num_to_fill = block_size if is_curr_block_full else random.randint(
0, block_size - 1)
token_ids = list(range(num_to_fill))
mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator)
block_with_prev = PrefixCachingBlock(
prev_block=None,
token_ids=token_ids,
block_size=block_size,
prefix_caching_allocator=mock_allocator)
if is_curr_block_full:
# Expect hash since block is full.
assert block_with_prev.content_hash == (
PrefixCachingBlock.hash_block_tokens(
is_first_block=True,
prev_block_hash=None,
cur_block_token_ids=token_ids))
else:
# Do not expect hash since block is not full.
assert block_with_prev.content_hash is None
@staticmethod
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize("block_size", [1, 16])
@pytest.mark.parametrize("is_curr_block_full", [True, False])
@pytest.mark.parametrize("prev_block_has_hash", [True, False])
def test_nth_block_has_correct_content_hash(seed: int, block_size: int,
is_curr_block_full: bool,
prev_block_has_hash: bool):
"""Verify a block which is not first in the sequence has the correct
hash.
"""
random.seed(seed)
previous_block = MagicMock(spec=PrefixCachingBlock)
prev_block_hash = random.randint(0, 1000)
previous_block.content_hash = (prev_block_hash
if prev_block_has_hash else None)
num_to_fill = block_size if is_curr_block_full else random.randint(
0, block_size - 1)
token_ids = list(range(num_to_fill))
mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator)
block_with_prev = PrefixCachingBlock(
prev_block=previous_block,
token_ids=token_ids,
block_size=block_size,
prefix_caching_allocator=mock_allocator,
)
if is_curr_block_full and prev_block_has_hash:
# Expect hash since block is full and previous block has hash.
assert (block_with_prev.content_hash ==
PrefixCachingBlock.hash_block_tokens(
is_first_block=False,
prev_block_hash=prev_block_hash,
cur_block_token_ids=token_ids))
else:
# Do not expect hash since block is not full or the previous block
# does not have a hash.
assert block_with_prev.content_hash is None
@staticmethod
@pytest.mark.parametrize("block_size", [1, 2, 16])
@pytest.mark.parametrize("num_tokens", list(range(3)))
@pytest.mark.parametrize("num_empty_trailing_blocks", [0, 1, 10])
def test_blocks_have_correct_hash_in_chain(block_size: int,
num_tokens: int,
num_empty_trailing_blocks: int):
"""Create two chains of logical blocks with the same contents.
Assert the hashes are equal.
"""
random.seed(0)
token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)]
first_chain, second_chain = [
TestPrefixCachingBlock.create_chain(
block_size=block_size,
token_ids=token_ids,
num_empty_trailing_blocks=num_empty_trailing_blocks)
for _ in range(2)
]
for first_chain_block, second_chain_block in zip(
first_chain, second_chain):
assert (first_chain_block.content_hash ==
second_chain_block.content_hash)
if not first_chain or not second_chain:
assert first_chain == second_chain
assert num_tokens == 0
@staticmethod
def create_chain(block_size: int,
token_ids: List[int],
num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]:
"""Helper method which creates a chain of blocks.
"""
blocks = []
num_blocks = math.ceil(
len(token_ids) / block_size) + num_empty_trailing_blocks
if num_blocks == 0:
return []
allocator = MagicMock(spec=PrefixCachingBlockAllocator)
prev_block = None
for block_number in range(0, num_blocks):
prev_block = PrefixCachingBlock(
prev_block=prev_block,
token_ids=[],
block_size=block_size,
prefix_caching_allocator=allocator,
)
tokens_to_append = token_ids[block_number *
block_size:(block_number + 1) *
block_size]
if tokens_to_append:
prev_block.append_token_ids(tokens_to_append)
blocks.append(prev_block)
return blocks
class TestPrefixCachingBlockAllocator:
@staticmethod
def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator,
prev_block: Optional[Block],
token_ids: List[int]):
if allocate_type == "immutable":
allocate_block = lambda: allocator.allocate_immutable(
prev_block=prev_block, token_ids=token_ids)
elif allocate_type == "mutable":
allocate_block = lambda: allocator.allocate_mutable(prev_block=
prev_block)
else:
raise ValueError()
return allocate_block
@staticmethod
@pytest.mark.parametrize("num_blocks", [1, 1024])
@pytest.mark.parametrize("block_size", [1, 16])
def test_allocate_mutable_ooms(num_blocks: int, block_size: int):
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda(
allocate_type="mutable",
allocator=allocator,
prev_block=None,
token_ids=list(range(block_size)),
)
[allocate_block() for _ in range(num_blocks)]
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocate_block()
@staticmethod
@pytest.mark.parametrize("num_blocks", [1, 1024])
@pytest.mark.parametrize("block_size", [1, 16])
def test_allocate_immutable_does_not_oom_single_hash(
num_blocks: int, block_size: int):
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda(
allocate_type="immutable",
allocator=allocator,
prev_block=None,
token_ids=list(range(block_size)),
)
blocks = [allocate_block() for _ in range(num_blocks)]
# Expect no OOM. If these were mutable blocks, this would OOM.
non_oom_block = allocate_block()
# Expect all blocks to have same physical block index.
for block in blocks:
assert (block.block_id == non_oom_block.block_id)
@staticmethod
@pytest.mark.parametrize("num_blocks", [1, 1024])
@pytest.mark.parametrize("block_size", [1, 16])
def test_allocate_immutable_ooms_many_hash(num_blocks: int,
block_size: int):
"""Consume all blocks using many different hashes/block content.
Do this by creating a sequence that is very long.
Expect next block to OOM.
"""
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
# Create token ids that will exhaust all blocks.
token_ids = list(range(num_blocks * block_size))
chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)
# Expect allocation with unseen hash to fail.
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocator.allocate_immutable(prev_block=chain[-1],
token_ids=list(range(block_size)))
# Expect mutable allocation to fail.
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocator.allocate_mutable(prev_block=chain[-1])
# Expect allocation of exact same chain to pass.
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)
# Expect physical block indices to be the same in both chains.
assert chain and second_chain
for first_chain_block, second_chain_block in zip(chain, second_chain):
assert (first_chain_block.block_id == second_chain_block.block_id)
@staticmethod
@pytest.mark.parametrize("num_blocks", [1, 1024])
@pytest.mark.parametrize("block_size", [1, 16])
def test_free_prevents_oom(num_blocks: int, block_size: int):
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
# Create token ids that will exhaust all blocks.
token_ids = list(range(num_blocks * block_size))
chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)
# Expect mutable allocation to fail.
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocator.allocate_mutable(prev_block=None)
block_to_free = chain[-1]
# Expect free/allocate loop to succeed many times.
for i in range(100):
block_id = block_to_free.block_id
allocator.free(block_to_free)
assert block_to_free.block_id is None, i
new_block = allocator.allocate_mutable(prev_block=None)
assert new_block.block_id == block_id, i
with pytest.raises(BlockAllocator.NoFreeBlocksError):
allocator.allocate_mutable(prev_block=None)
block_to_free = new_block
@staticmethod
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("seed", list(range(20)))
def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int):
random.seed(seed)
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
num_blocks_to_consume = random.randint(1, num_blocks - 1)
# Create token ids that will exhaust all blocks.
token_ids = list(range(num_blocks_to_consume * block_size))
chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)
# Free each block in chain, assert num free blocks includes new free
# block.
for i, block in enumerate(chain):
assert allocator.get_num_free_blocks() == (num_blocks -
num_blocks_to_consume +
i)
allocator.free(block)
@staticmethod
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("seed", list(range(20)))
def test_get_num_free_blocks_shared(num_blocks: int, block_size: int,
seed: int):
"""Verify sharing occurs by allocating two sequences that share prefixes
and incrementally freeing blocks.
"""
random.seed(seed)
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
num_blocks_to_consume = random.randint(1, num_blocks - 1)
# Create token ids that will exhaust all blocks.
token_ids = list(range(num_blocks_to_consume * block_size))
first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)
# Free each block in the first chain. Since all blocks are shared, the
# free count should stay constant.
for i, block in enumerate(first_chain):
assert allocator.get_num_free_blocks() == (num_blocks -
num_blocks_to_consume)
allocator.free(block)
# Free each block in the second chain. Since the refcount is now zero,
# the free count should increment with each free.
for i, block in enumerate(second_chain):
assert allocator.get_num_free_blocks() == (num_blocks -
num_blocks_to_consume +
i)
allocator.free(block)
@staticmethod
def create_immutable_chain(
block_size: int,
token_ids: List[int],
allocator: PrefixCachingBlockAllocator,
) -> List[PrefixCachingBlock]:
"""Helper method which creates a chain of blocks.
"""
blocks = []
num_blocks = math.ceil(len(token_ids) / block_size)
if num_blocks == 0:
return []
prev_block = None
for block_number in range(0, num_blocks):
block_token_ids = token_ids[block_number *
block_size:(block_number + 1) *
block_size]
prev_block = allocator.allocate_immutable(
prev_block=prev_block, token_ids=block_token_ids)
blocks.append(prev_block)
return blocks
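
The chaining property these tests rely on can be illustrated with a toy hash: a full block's content hash is derived from whether it is the first block, the previous block's hash, and its own token ids, so two chains built from identical tokens must produce identical hashes. The sketch below uses Python's built-in hash as a placeholder; the actual PrefixCachingBlock.hash_block_tokens scheme is not reproduced here.

# Toy illustration of chained content hashing (placeholder hash, not vLLM's).
from typing import List, Optional

def toy_hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
                          cur_block_token_ids: List[int]) -> int:
    return hash((is_first_block, prev_block_hash, tuple(cur_block_token_ids)))

def toy_chain_hashes(token_ids: List[int], block_size: int) -> List[Optional[int]]:
    # Hash each full block; a trailing partial block gets no hash, as the tests expect.
    hashes: List[Optional[int]] = []
    prev_hash: Optional[int] = None
    for start in range(0, len(token_ids), block_size):
        block = token_ids[start:start + block_size]
        if len(block) == block_size:
            prev_hash = toy_hash_block_tokens(start == 0, prev_hash, block)
        else:
            prev_hash = None
        hashes.append(prev_hash)
    return hashes

# Two chains built from identical tokens produce identical hashes.
tokens = list(range(40))
assert toy_chain_hashes(tokens, 16) == toy_chain_hashes(list(tokens), 16)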


@@ -5,8 +5,9 @@ import pytest
from vllm import SamplingParams
from vllm.block import PhysicalTokenBlock
-from vllm.core.block_manager import (AllocStatus, BlockSpaceManager,
-UncachedBlockAllocator)
+from vllm.core.block_manager_v1 import (BlockSpaceManagerV1,
+UncachedBlockAllocator)
+from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
@@ -63,10 +64,10 @@ def test_allocate():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
-block_manager = BlockSpaceManager(block_size,
-num_cpu_blocks,
-num_gpu_blocks,
-watermark=0)
+block_manager = BlockSpaceManagerV1(block_size,
+num_cpu_blocks,
+num_gpu_blocks,
+watermark=0)
# Allocate same sequence group to all available gpu blocks.
for i in range(num_gpu_blocks):
@@ -77,10 +78,10 @@ def test_allocate():
# Allocate same sequence group to all available gpu blocks.
# Use watermark to reserve one gpu block.
-block_manager = BlockSpaceManager(block_size,
-num_cpu_blocks,
-num_gpu_blocks,
-watermark=1 / num_gpu_blocks)
+block_manager = BlockSpaceManagerV1(block_size,
+num_cpu_blocks,
+num_gpu_blocks,
+watermark=1 / num_gpu_blocks)
for i in range(num_gpu_blocks - 1):
_, seq_group = create_dummy_prompt(str(i), block_size)
assert block_manager.can_allocate(seq_group)
@@ -92,10 +93,10 @@ def test_append_slot_single_seq():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
-block_manager = BlockSpaceManager(block_size,
-num_cpu_blocks,
-num_gpu_blocks,
-watermark=0)
+block_manager = BlockSpaceManagerV1(block_size,
+num_cpu_blocks,
+num_gpu_blocks,
+watermark=0)
# Allocate single seq to gpu block.
prompt, seq_group = create_dummy_prompt("1", block_size)
@@ -124,10 +125,10 @@ def test_append_slot_cow():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
-block_manager = BlockSpaceManager(block_size=block_size,
-num_cpu_blocks=num_cpu_blocks,
-num_gpu_blocks=num_gpu_blocks,
-watermark=0)
+block_manager = BlockSpaceManagerV1(block_size=block_size,
+num_cpu_blocks=num_cpu_blocks,
+num_gpu_blocks=num_gpu_blocks,
+watermark=0)
# Allocate prompt to gpu block. There is one slot left in the block.
prompt = Sequence(seq_id=1,
@@ -165,10 +166,10 @@ def test_fork():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
-block_manager = BlockSpaceManager(block_size,
-num_cpu_blocks,
-num_gpu_blocks,
-watermark=0)
+block_manager = BlockSpaceManagerV1(block_size,
+num_cpu_blocks,
+num_gpu_blocks,
+watermark=0)
prompt, seq_group = create_dummy_prompt("1",
block_size - 1,
@@ -192,10 +193,10 @@ def test_swap():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
-block_manager = BlockSpaceManager(block_size,
-num_cpu_blocks,
-num_gpu_blocks,
-watermark=0)
+block_manager = BlockSpaceManagerV1(block_size,
+num_cpu_blocks,
+num_gpu_blocks,
+watermark=0)
prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
prompt.status = SequenceStatus.WAITING
@@ -238,10 +239,10 @@ def test_free():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
-block_manager = BlockSpaceManager(block_size,
-num_cpu_blocks,
-num_gpu_blocks,
-watermark=0)
+block_manager = BlockSpaceManagerV1(block_size,
+num_cpu_blocks,
+num_gpu_blocks,
+watermark=0)
prompt, seq_group = create_dummy_prompt("1", block_size)
block_manager.allocate(seq_group)
@@ -262,10 +263,10 @@ def test_reset():
block_size = 4
num_cpu_blocks = 4
num_gpu_blocks = 4
-block_manager = BlockSpaceManager(block_size,
-num_cpu_blocks,
-num_gpu_blocks,
-watermark=0)
+block_manager = BlockSpaceManagerV1(block_size,
+num_cpu_blocks,
+num_gpu_blocks,
+watermark=0)
# Allocate same seq group on all available gpu blocks.
original_blocks = block_manager.get_num_free_gpu_blocks()
@@ -289,11 +290,11 @@ def test_sliding_window_multi_seq():
num_cpu_blocks = 8
num_gpu_blocks = 8
sliding_window = 2
-block_manager = BlockSpaceManager(block_size,
-num_cpu_blocks,
-num_gpu_blocks,
-sliding_window=sliding_window,
-watermark=0)
+block_manager = BlockSpaceManagerV1(block_size,
+num_cpu_blocks,
+num_gpu_blocks,
+sliding_window=sliding_window,
+watermark=0)
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks


@@ -13,7 +13,7 @@ from .utils import create_dummy_prompt
def test_scheduler_add_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1)
-cache_config = CacheConfig(block_size, 1.0, 1, "auto")
+cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4
scheduler = Scheduler(scheduler_config, cache_config, None)


@@ -2,7 +2,7 @@ import time
from typing import Tuple
from vllm import SamplingParams
-from vllm.sequence import Sequence, SequenceGroup
+from vllm.sequence import Logprob, Sequence, SequenceGroup
def create_dummy_prompt(
@@ -23,5 +23,42 @@ def create_dummy_prompt(
return prompt, seq_group
def create_seq_group(
seq_prompt_lens=1024,
seq_output_lens=(128, ),
request_id='0',
seq_id_start=0,
) -> SequenceGroup:
assert len(seq_output_lens) > 0
prompt_token_ids = [0] * seq_prompt_lens
seqs = []
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
prompt="",
prompt_token_ids=prompt_token_ids,
block_size=16,
)
for i in range(output_len):
seq.append_token_id(
token_id=i,
logprobs={i: Logprob(0.0)},
)
seqs.append(seq)
seq_group = SequenceGroup(
request_id=request_id,
seqs=seqs,
sampling_params=SamplingParams(),
arrival_time=time.time(),
)
return seq_group
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
return (seq_len + block_size - 1) // block_size
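
A quick illustration of how these helpers are consumed by the new block-manager tests above (the shapes mirror the BlockSpaceManagerV2 test; values are illustrative and this snippet is not part of the diff):

# Illustrative usage of the helpers defined above.
block_size = 16
# A group with a two-block prompt and four sequences, each generating one
# block of output.
seq_group = create_seq_group(
    seq_prompt_lens=block_size * 2,
    seq_output_lens=[block_size for _ in range(4)],
)
# 35 tokens with block_size 16 round up to 3 blocks.
assert round_up_to_next_block(block_size * 2 + 3, block_size) == 3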


@@ -4,7 +4,7 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest
-from vllm.core.block_manager import CachedBlockAllocator
+from vllm.core.block_manager_v1 import CachedBlockAllocator
from vllm.utils import Device


@@ -324,6 +324,8 @@ class CacheConfig:
vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage.
+forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the
+profiled num_gpu_blocks if specified. Does nothing if None.
"""
def __init__(
@@ -332,12 +334,14 @@ class CacheConfig:
gpu_memory_utilization: float,
swap_space: int,
cache_dtype: str,
+forced_num_gpu_blocks: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
+self.forced_num_gpu_blocks = forced_num_gpu_blocks
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
@@ -528,6 +532,7 @@ class SchedulerConfig:
and generated text).
delay_factor: Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt.
+use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
"""
def __init__(
@@ -535,6 +540,7 @@ class SchedulerConfig:
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
+use_v2_block_manager: bool = False,
delay_factor: float = 0.0,
) -> None:
if max_num_batched_tokens is not None:
@@ -546,6 +552,7 @@ class SchedulerConfig:
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.delay_factor = delay_factor
+self.use_v2_block_manager = use_v2_block_manager
self._verify_args()
def _verify_args(self) -> None:
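
A hedged sketch of how the two new knobs introduced in this hunk can be exercised, based on the constructor signatures above and on the e2e test earlier in this commit, which passes them straight through as LLM kwargs (the remaining argument values are illustrative):

# Illustrative only; values other than the two new knobs are placeholders.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,
    block_size=16,
    forced_num_gpu_blocks=5 * (64 + 1),   # overrides the profiled GPU block count
    use_v2_block_manager=True,            # opt in to BlockSpaceManagerV2
)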


@@ -0,0 +1,245 @@
from typing import List, Optional
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
from vllm.utils import Device, cdiv, chunk_list
class BlockTable:
"""A class to manage blocks for a specific sequence.
The BlockTable maps a sequence of tokens to a list of blocks, where each
block represents a contiguous memory allocation for a portion of the
sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
responsible for allocating and freeing memory for the blocks.
Args:
block_size (int): The maximum number of tokens that can be stored in a
single block.
block_allocator (DeviceAwareBlockAllocator): The block allocator used to
manage memory for the blocks.
_blocks (Optional[List[Block]], optional): An optional list of existing
blocks to initialize the BlockTable with. If not provided, an empty
BlockTable is created.
Attributes:
_block_size (int): The maximum number of tokens that can be stored in a
single block.
_allocator (DeviceAwareBlockAllocator): The block allocator used to
manage memory for the blocks.
_blocks (Optional[List[Block]]): The list of blocks managed by this
BlockTable.
_num_full_slots (int): The number of tokens currently stored in the
blocks.
"""
def __init__(
self,
block_size: int,
block_allocator: DeviceAwareBlockAllocator,
_blocks: Optional[List[Block]] = None,
):
self._block_size = block_size
self._allocator = block_allocator
self._blocks: Optional[List[Block]] = _blocks
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
self._num_full_slots = len(self._get_all_token_ids())
@staticmethod
def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
"""Calculates the minimum number of blocks required to store a given
sequence of token IDs.
This assumes worst-case scenario, where every block requires a new
allocation (e.g. ignoring prefix caching).
Args:
token_ids (List[int]): The sequence of token IDs to be stored.
block_size (int): The maximum number of tokens that can be stored in
a single block.
Returns:
int: The minimum number of blocks required to store the given
sequence of token IDs.
"""
return cdiv(len(token_ids), block_size)
def allocate(self,
token_ids: List[int],
device: Device = Device.GPU) -> None:
"""Allocates memory blocks for storing the given sequence of token IDs.
This method allocates the required number of blocks to store the given
sequence of token IDs.
Args:
token_ids (List[int]): The sequence of token IDs to be stored.
device (Device, optional): The device on which the blocks should be
allocated. Defaults to Device.GPU.
"""
assert not self._is_allocated
assert token_ids
self._blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device)
self._num_full_slots = len(token_ids)
def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
This method appends the given sequence of token IDs to the existing
blocks in the BlockTable. If there is not enough space in the existing
blocks, new blocks are allocated using the `ensure_num_empty_slots`
method to accommodate the additional tokens.
The token IDs are divided into chunks of size `block_size` (except for
the first chunk, which may be smaller), and each chunk is appended to a
separate block.
Args:
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert self._is_allocated
self.ensure_num_empty_slots(num_empty_slots=len(token_ids))
blocks = self._blocks[self._num_full_slots // self._block_size:]
first_chunk_size = self._block_size - (self._num_full_slots %
self._block_size)
token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
token_ids[first_chunk_size:], self._block_size)
for block, token_block in zip(blocks, token_blocks):
block.append_token_ids(token_block)
self._num_full_slots += len(token_ids)
def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
"""Ensures that the BlockTable has at least the specified number of
empty slots available.
This method checks if the BlockTable has enough empty slots (i.e.,
available space) to accommodate the requested number of tokens. If not,
it allocates additional blocks on the GPU to ensure that the required
number of empty slots is available.
Args:
num_empty_slots (int): The minimum number of empty slots required.
"""
# Currently the block table only supports
# appending tokens to GPU blocks.
device = Device.GPU
assert self._is_allocated
if self._num_empty_slots >= num_empty_slots:
return
slots_to_allocate = num_empty_slots - self._num_empty_slots
blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
for _ in range(blocks_to_allocate):
self._blocks.append(
self._allocator.allocate_mutable(prev_block=self._blocks[-1],
device=device))
def fork(self) -> "BlockTable":
"""Creates a new BlockTable instance with a copy of the blocks from the
current instance.
This method creates a new BlockTable instance with the same block size,
block allocator, and a copy of the blocks from the current instance. The
new BlockTable has its own independent set of blocks, but shares the
same underlying memory allocation with the original BlockTable.
Returns:
BlockTable: A new BlockTable instance with a copy of the blocks from
the current instance.
"""
assert self._is_allocated
forked_blocks = self._allocator.fork(self._blocks[-1])
return BlockTable(
block_size=self._block_size,
block_allocator=self._allocator,
_blocks=forked_blocks,
)
def free(self) -> None:
"""Frees the memory occupied by the blocks in the BlockTable.
This method iterates over all the blocks in the `_blocks` list and calls
the `free` method of the `_allocator` object to release the memory
occupied by each block. After freeing all the blocks, the `_blocks` list
is set to `None`.
"""
assert self._is_allocated
for block in self._blocks:
self._allocator.free(block)
self._blocks = None
@property
def physical_block_ids(self) -> List[int]:
"""Returns a list of physical block indices for the blocks in the
BlockTable.
This property returns a list of integers, where each integer represents
the physical block index of a corresponding block in the `_blocks` list.
The physical block index is a unique identifier for the memory location
occupied by the block.
Returns:
List[int]: A list of physical block indices for the blocks in the
BlockTable.
"""
assert self._is_allocated
return [block.block_id for block in self._blocks]
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> List[Block]:
blocks = []
for block_token_ids in chunk_list(token_ids, self._block_size):
if len(block_token_ids) == self._block_size:
# If the block is full, create an immutable block.
prev_block = self._allocator.allocate_immutable(
prev_block, token_ids=block_token_ids, device=device)
else:
# Else, partially fill a mutable block with token ids.
prev_block = self._allocator.allocate_mutable(
prev_block=prev_block, device=device)
prev_block.append_token_ids(block_token_ids)
blocks.append(prev_block)
return blocks
def _get_all_token_ids(self) -> List[int]:
# NOTE: This function is O(seq_len); use sparingly.
token_ids = []
if not self._is_allocated:
return token_ids
for block in self._blocks:
token_ids.extend(block.token_ids)
return token_ids
@property
def _is_allocated(self) -> bool:
return self._blocks is not None
@property
def _num_empty_slots(self) -> int:
assert self._is_allocated
return len(self._blocks) * self._block_size - self._num_full_slots
@property
def num_full_slots(self) -> int:
"""Returns the total number of tokens currently stored in the
BlockTable.
Returns:
int: The total number of tokens currently stored in the BlockTable.
"""
return self._num_full_slots
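
Putting the pieces together, a short usage sketch of BlockTable driven by a CpuGpuBlockAllocator, following the patterns exercised in the block-table tests above (a sketch only, not new functionality added by this commit):

# Illustrative usage of BlockTable with a device-aware allocator.
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.utils import Device

allocator = CpuGpuBlockAllocator.create(
    allocator_type="prefix_caching",  # or "naive"
    num_gpu_blocks=1024,
    num_cpu_blocks=1024,
    block_size=16,
)
block_table = BlockTable(block_size=16, block_allocator=allocator)

block_table.allocate(token_ids=list(range(100)), device=Device.GPU)
block_table.ensure_num_empty_slots(8)        # pre-allocate lookahead slots
block_table.append_token_ids(list(range(8)))

forked = block_table.fork()                  # shares physical blocks (copy-on-write)
assert forked.physical_block_ids == block_table.physical_block_ids

block_table.free()
forked.free()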

vllm/core/block/common.py

@@ -0,0 +1,185 @@
from collections import defaultdict
from typing import Dict, Iterable, List, Optional
from vllm.core.block.interfaces import Block, BlockAllocator
BlockId = int
RefCount = int
class RefCounter:
"""A class for managing reference counts for a set of block indices.
The RefCounter class maintains a dictionary that maps block indices to their
corresponding reference counts. It provides methods to increment, decrement,
and retrieve the reference count for a given block index.
Args:
all_block_indices (Iterable[BlockId]): An iterable of block indices
to initialize the reference counter with.
"""
def __init__(self, all_block_indices: Iterable[BlockId]):
deduped = set(all_block_indices)
self._refcounts: Dict[BlockId,
RefCount] = {index: 0
for index in deduped}
def incr(self, block_id: BlockId) -> RefCount:
assert block_id in self._refcounts
pre_incr_refcount = self._refcounts[block_id]
assert pre_incr_refcount >= 0
post_incr_refcount = pre_incr_refcount + 1
self._refcounts[block_id] = post_incr_refcount
return post_incr_refcount
def decr(self, block_id: BlockId) -> RefCount:
assert block_id in self._refcounts
refcount = self._refcounts[block_id]
assert refcount > 0
refcount -= 1
self._refcounts[block_id] = refcount
return refcount
def get(self, block_id: BlockId) -> RefCount:
assert block_id in self._refcounts
return self._refcounts[block_id]
def as_readonly(self) -> "ReadOnlyRefCounter":
return ReadOnlyRefCounter(self)
class ReadOnlyRefCounter:
"""A read-only view of the RefCounter class.
The ReadOnlyRefCounter class provides a read-only interface to access the
reference counts maintained by a RefCounter instance. It does not allow
modifications to the reference counts.
Args:
refcounter (RefCounter): The RefCounter instance to create a read-only
view for.
"""
def __init__(self, refcounter: RefCounter):
self._refcounter = refcounter
def incr(self, block_id: BlockId) -> RefCount:
raise ValueError("Incr not allowed")
def decr(self, block_id: BlockId) -> RefCount:
raise ValueError("Decr not allowed")
def get(self, block_id: BlockId) -> RefCount:
return self._refcounter.get(block_id)
class CopyOnWriteTracker:
"""A class for tracking and managing copy-on-write operations for blocks.
The CopyOnWriteTracker class maintains a mapping of source block indices to
their corresponding copy-on-write destination block indices. It works in
conjunction with a RefCounter and a BlockAllocator to handle reference
counting and block allocation.
Args:
refcounter (RefCounter): The reference counter used to track block
reference counts.
allocator (BlockAllocator): The block allocator used to allocate and
free blocks.
"""
def __init__(
self,
refcounter: RefCounter,
allocator: BlockAllocator,
):
self._copy_on_writes = defaultdict(list)
self._refcounter = refcounter
self._allocator = allocator
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
This method checks the reference count of the given block. If the
reference count is greater than 1, indicating that the block is shared,
a copy-on-write operation is performed. The original block is freed,
and a new block is allocated with the same content. The new block index
is returned.
Args:
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
no copy-on-write was necessary.
"""
block_id = block.block_id
if block_id is None:
return block_id
refcount = self._refcounter.get(block_id)
assert refcount != 0
if refcount > 1:
src_block_id = block_id
# Decrement refcount of the old block.
self._allocator.free(block)
# Allocate a fresh new block.
block_id = self._allocator.allocate_mutable(
prev_block=block.prev_block).block_id
# Track src/dst copy.
self._copy_on_writes[src_block_id].append(block_id)
return block_id
def clear_cows(self) -> Dict[BlockId, List[BlockId]]:
"""Clears the copy-on-write tracking information and returns the current
state.
This method returns a dictionary mapping source block indices to lists
of destination block indices for the current copy-on-write operations.
It then clears the internal tracking information.
Returns:
Dict[BlockId, List[BlockId]]: A dictionary mapping source
block indices to lists of destination block indices for the
current copy-on-write operations.
"""
cows = dict(self._copy_on_writes)
self._copy_on_writes.clear()
return cows
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block.
This function recursively traverses the sequence of blocks in reverse order,
starting from the given last block, and returns a list of all the blocks in
the sequence.
Args:
last_block (Block): The last block in the sequence.
Returns:
List[Block]: A list of all the blocks in the sequence, in the order they
appear.
"""
def recurse(block: Block, lst: List[Block]) -> None:
if block.prev_block is not None:
recurse(block.prev_block, lst)
lst.append(block)
all_blocks = []
recurse(last_block, all_blocks)
return all_blocks
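A small sketch of the RefCounter behaviour documented above (block indices are illustrative). The read-only view is what the CopyOnWriteTracker holds, so copy-on-write decisions can inspect refcounts without mutating them:
from vllm.core.block.common import RefCounter

refcounter = RefCounter(all_block_indices=range(4))
assert refcounter.get(0) == 0         # blocks start unreferenced
assert refcounter.incr(0) == 1
assert refcounter.incr(0) == 2        # block 0 is now shared
assert refcounter.decr(0) == 1
readonly = refcounter.as_readonly()
assert readonly.get(0) == 1           # reads allowed; incr()/decr() raise ValueError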

View File

@ -0,0 +1,206 @@
from typing import Dict, List, Optional
from vllm.core.block.interfaces import (Block, BlockAllocator,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
from vllm.utils import Device
class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
"""A block allocator that can allocate blocks on both CPU and GPU memory.
This class implements the `DeviceAwareBlockAllocator` interface and provides
functionality for allocating and managing blocks of memory on both CPU and
GPU devices.
The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU
blocks, and allows for allocation, deallocation, forking, and swapping of
blocks across these memory pools.
"""
@staticmethod
def create(
allocator_type: str,
num_gpu_blocks: int,
num_cpu_blocks: int,
block_size: int,
) -> DeviceAwareBlockAllocator:
"""Creates a CpuGpuBlockAllocator instance with the specified
configuration.
This static method creates and returns a CpuGpuBlockAllocator instance
based on the provided parameters. It initializes the CPU and GPU block
allocators with the specified number of blocks, block size, and
allocator type.
Args:
allocator_type (str): The type of block allocator to use for CPU
and GPU blocks. Currently supported values are "naive" and
"prefix_caching".
num_gpu_blocks (int): The number of blocks to allocate for GPU
memory.
num_cpu_blocks (int): The number of blocks to allocate for CPU
memory.
block_size (int): The size of each block in number of tokens.
Returns:
DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
specified configuration.
Notes:
- The block IDs are assigned contiguously, with GPU block IDs coming
before CPU block IDs.
"""
block_ids = list(range(num_gpu_blocks + num_cpu_blocks))
gpu_block_ids = block_ids[:num_gpu_blocks]
cpu_block_ids = block_ids[num_gpu_blocks:]
if allocator_type == "naive":
gpu_allocator = NaiveBlockAllocator(
create_block=NaiveBlock,
num_blocks=num_gpu_blocks,
block_size=block_size,
block_ids=gpu_block_ids,
)
cpu_allocator = NaiveBlockAllocator(
create_block=NaiveBlock,
num_blocks=num_cpu_blocks,
block_size=block_size,
block_ids=cpu_block_ids,
)
elif allocator_type == "prefix_caching":
gpu_allocator = PrefixCachingBlockAllocator(
num_blocks=num_gpu_blocks,
block_size=block_size,
block_ids=gpu_block_ids,
)
cpu_allocator = PrefixCachingBlockAllocator(
num_blocks=num_cpu_blocks,
block_size=block_size,
block_ids=cpu_block_ids,
)
else:
raise ValueError(f"Unknown allocator type {allocator_type=}")
return CpuGpuBlockAllocator(
cpu_block_allocator=cpu_allocator,
gpu_block_allocator=gpu_allocator,
)
def __init__(
self,
cpu_block_allocator: BlockAllocator,
gpu_block_allocator: BlockAllocator,
):
assert not (
cpu_block_allocator.all_block_ids
& gpu_block_allocator.all_block_ids
), "cpu and gpu block allocators can't have intersection of block ids"
self._allocators = {
Device.CPU: cpu_block_allocator,
Device.GPU: gpu_block_allocator,
}
self._block_ids_to_allocator = {}
for _, allocator in self._allocators.items():
for block_id in allocator.all_block_ids:
self._block_ids_to_allocator[block_id] = allocator
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
"""Allocates a new mutable block on the specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
device (Device): The device on which to allocate the new block.
Returns:
Block: The newly allocated mutable block.
"""
return self._allocators[device].allocate_mutable(prev_block)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
"""Allocates a new immutable block with the provided token IDs on the
specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
token_ids (List[int]): The list of token IDs to be stored in the new
block.
device (Device): The device on which to allocate the new block.
Returns:
Block: The newly allocated immutable block containing the provided
token IDs.
"""
return self._allocators[device].allocate_immutable(
prev_block, token_ids)
def free(self, block: Block) -> None:
"""Frees the memory occupied by the given block.
Args:
block (Block): The block to be freed.
"""
allocator = self._block_ids_to_allocator[block.block_id]
return allocator.free(block)
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
Args:
last_block (Block): The last block in the original sequence.
Returns:
List[Block]: A new list of blocks that shares the same memory as the
original sequence.
"""
allocator = self._block_ids_to_allocator[last_block.block_id]
return allocator.fork(last_block)
def get_num_free_blocks(self, device: Device) -> int:
"""Returns the number of free blocks available on the specified device.
Args:
device (Device): The device for which to query the number of free
blocks.
Returns:
int: The number of free blocks available on the specified device.
"""
return self._allocators[device].get_num_free_blocks()
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
source to destination block IDs.
Returns:
Dict[int, List[int]]: A dictionary mapping source block IDs to lists
of destination block IDs.
"""
# CoW only supported on GPU
device = Device.GPU
return self._allocators[device].clear_copy_on_writes()
def mark_blocks_as_computed(self) -> None:
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].mark_blocks_as_computed()
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
# Prefix caching only supported on GPU.
device = Device.GPU
return self._allocators[device].get_common_computed_block_ids(
seq_block_ids)
@property
def all_block_ids(self) -> frozenset[int]:
return frozenset(self._block_ids_to_allocator.keys())
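A sketch of the factory and per-device routing described above (block counts are illustrative):
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.utils import Device

allocator = CpuGpuBlockAllocator.create(allocator_type="naive",
                                        num_gpu_blocks=8,
                                        num_cpu_blocks=8,
                                        block_size=16)
gpu_block = allocator.allocate_mutable(prev_block=None, device=Device.GPU)
cpu_block = allocator.allocate_immutable(prev_block=None,
                                         token_ids=list(range(16)),
                                         device=Device.CPU)
assert allocator.get_num_free_blocks(Device.GPU) == 7
assert allocator.get_num_free_blocks(Device.CPU) == 7
allocator.free(gpu_block)             # routed to the GPU pool via its block id
assert allocator.get_num_free_blocks(Device.GPU) == 8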

View File

@ -0,0 +1,105 @@
from abc import ABC, abstractmethod, abstractproperty
from typing import Dict, List, Optional, Protocol
from vllm.utils import Device
class Block(ABC):
@abstractmethod
def append_token_ids(self, token_ids: List[int]) -> None:
pass
@abstractproperty
def block_id(self) -> Optional[int]:
pass
@abstractproperty
def token_ids(self) -> List[int]:
pass
@abstractproperty
def num_empty_slots(self) -> int:
pass
@abstractproperty
def is_full(self) -> bool:
pass
@abstractproperty
def prev_block(self) -> Optional["Block"]:
pass
class Factory(Protocol):
@abstractmethod
def __call__(
self,
prev_block: Optional["Block"],
token_ids: List[int],
block_size: int,
allocator: "BlockAllocator",
block_id: Optional[int] = None,
) -> "Block":
pass
class BlockAllocator(ABC):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
pass
@abstractmethod
def free(self, block: Block) -> None:
pass
@abstractmethod
def fork(self, last_block: Block) -> List[Block]:
pass
@abstractmethod
def get_num_free_blocks(self) -> int:
pass
@abstractproperty
def all_block_ids(self) -> frozenset[int]:
pass
@abstractmethod
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
pass
@abstractmethod
def mark_blocks_as_computed(self) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
pass
class NoFreeBlocksError(ValueError):
pass
class DeviceAwareBlockAllocator(BlockAllocator):
@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
pass
@abstractmethod
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block:
pass
@abstractmethod
def get_num_free_blocks(self, device: Device) -> int:
pass
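NoFreeBlocksError is how allocators signal pool exhaustion to their callers; a minimal sketch using the NaiveBlockAllocator defined later in this commit (the pool size is illustrative):
from vllm.core.block.interfaces import BlockAllocator
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator

allocator = NaiveBlockAllocator(create_block=NaiveBlock,
                                num_blocks=1,
                                block_size=4)
first = allocator.allocate_mutable(prev_block=None)
try:
    allocator.allocate_mutable(prev_block=first)
except BlockAllocator.NoFreeBlocksError:
    # Callers such as the prefix caching allocator catch this and try to
    # evict an unused cached block before giving up.
    pass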

View File

@ -0,0 +1,275 @@
from typing import Dict, Iterable, List, Optional, Set
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator
BlockId = int
Refcount = int
class NaiveBlockAllocator(BlockAllocator):
"""A simple block allocator that manages blocks of memory without prefix
caching.
Args:
create_block (Block.Factory): A factory function for creating new
blocks. This is used when a NaiveBlockAllocator is composed within
a prefix caching allocator -- the naive block allocator must
construct prefix caching blocks (but shouldn't know anything else
about them).
num_blocks (int): The total number of blocks to manage.
block_size (int): The size of each block in tokens.
block_ids (Optional[Iterable[int]], optional): An optional iterable of
block IDs. If not provided, block IDs will be assigned sequentially
from 0 to num_blocks - 1.
"""
def __init__(
self,
create_block: Block.Factory,
num_blocks: int,
block_size: int,
block_ids: Optional[Iterable[int]] = None,
):
if block_ids is None:
block_ids = range(num_blocks)
self._free_block_indices: Set[BlockId] = set(block_ids)
self._all_block_indices = frozenset(block_ids)
assert len(self._all_block_indices) == num_blocks
self._refcounter = RefCounter(
all_block_indices=self._free_block_indices)
self._create_block = create_block
self._block_size = block_size
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly(),
allocator=self,
)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
Args:
prev_block (Optional[Block]): The previous block in the sequence. If
None, then the block to be allocated is the first block in the
sequence.
token_ids (List[int]): The token IDs to be stored in the new block.
Returns:
Block: The newly allocated immutable block.
"""
block = self.allocate_mutable(prev_block=prev_block)
block.append_token_ids(token_ids)
return block
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
"""Allocates a new mutable block, linked to the previous block.
Args:
prev_block (Optional[Block]): The previous block in the sequence. If
None, then the block to be allocated is the first block in the
sequence.
Returns:
Block: The newly allocated mutable block.
"""
block_id = self._allocate_new_block_id()
return self._create_block(
prev_block=prev_block,
token_ids=[],
block_id=block_id,
block_size=self._block_size,
allocator=self,
)
def free(self, block: Block) -> None:
self._free_block_id(block.block_id)
# Mark the block as having no allocation.
block.block_id = None
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
Args:
last_block (Block): The last block in the original sequence.
Returns:
List[Block]: The new sequence of blocks that shares the same memory
as the original sequence.
"""
source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = []
prev_block = None
for block in source_blocks:
# Increment refcount for each block.
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
forked_blocks.append(
self._create_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_id=block.block_id,
block_size=self._block_size,
allocator=self,
))
prev_block = forked_blocks[-1]
return forked_blocks
def get_num_free_blocks(self) -> int:
return len(self._free_block_indices)
def _allocate_new_block_id(self) -> BlockId:
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()
block_id = next(iter(self._free_block_indices))
self._refcounter.incr(block_id)
self._free_block_indices.remove(block_id)
return block_id
def _free_block_id(self, block_id: BlockId) -> None:
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.add(block_id)
@property
def refcounter(self):
return self._refcounter
@property
def all_block_ids(self):
return self._all_block_indices
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
Args:
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
no copy-on-write was necessary.
"""
return self._cow_tracker.cow_block_if_not_appendable(block)
def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it.
Returns:
Dict[BlockId, List[BlockId]]: A dictionary mapping source
block indices to lists of destination block indices.
"""
return self._cow_tracker.clear_cows()
def mark_blocks_as_computed(self) -> None:
"""Mark blocks as computed, used in prefix caching.
Since the naive allocator does not implement prefix caching, we do
nothing.
"""
pass
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
"""Determine blocks that can be skipped in prefill.
Since the naive allocator does not support prefix caching, always return
an empty list.
"""
return []
class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix
caching.
The NaiveBlock class represents a block of token IDs with a fixed size. It
provides methods for appending token IDs to the block and manages copy-on
-write operations when necessary.
Args:
prev_block (Block): The previous block in the sequence.
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
allocator (BlockAllocator): The block allocator associated with this
block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None, which means no allocation has been
made.
_cow_target (Optional[Block], optional): The copy-on-write target block.
If not provided, it defaults to self.
"""
def __init__(self,
prev_block: Block,
token_ids: List[int],
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
_cow_target: Optional[Block] = None):
self._token_ids = []
self._block_size = block_size
self._prev_block = prev_block
self._block_id = block_id
self._allocator = allocator
self._cow_target = _cow_target if _cow_target is not None else self
self._append_token_ids_no_cow(token_ids)
def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block, instructing the allocator
to perform a copy-on-write if necessary.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
self._append_token_ids_no_cow(token_ids)
if self._block_id is not None:
self._block_id = (self._allocator.cow_block_if_not_appendable(
self._cow_target))
def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
assert self.num_empty_slots >= len(token_ids)
self._token_ids.extend(token_ids)
@property
def block_id(self) -> Optional[int]:
return self._block_id
@block_id.setter
def block_id(self, value: Optional[int]) -> None:
self._block_id = value
@property
def is_full(self) -> bool:
return self.num_empty_slots == 0
@property
def num_empty_slots(self) -> int:
return self._block_size - len(self._token_ids)
@property
def token_ids(self) -> List[int]:
return self._token_ids
@property
def block_size(self) -> int:
return self._block_size
@property
def prev_block(self) -> Optional["Block"]:
return self._prev_block
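A sketch of the fork-then-append copy-on-write flow handled by NaiveBlock and NaiveBlockAllocator (sizes and token ids are illustrative):
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator

allocator = NaiveBlockAllocator(create_block=NaiveBlock,
                                num_blocks=4,
                                block_size=8)
original = allocator.allocate_immutable(prev_block=None, token_ids=[1, 2, 3])
(forked,) = allocator.fork(original)
assert forked.block_id == original.block_id   # shared physical block, refcount 2

forked.append_token_ids([4])                  # shared block is not appendable
assert forked.block_id != original.block_id   # CoW moved the fork to a new block
cows = allocator.clear_copy_on_writes()
# cows maps the shared source block id to the destination block id(s) that
# must be copied, e.g. {<old id>: [<new id>]}.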

View File

@ -0,0 +1,472 @@
"""Token blocks."""
from itertools import takewhile
from os.path import commonprefix
from typing import Dict, Iterable, List, Optional
from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
PrefixHash = int
BlockId = int
class PrefixCachingBlockAllocator(BlockAllocator):
"""A block allocator that implements prefix caching.
The PrefixCachingBlockAllocator maintains a cache of blocks based on their
content hash. It reuses blocks with the same content hash to avoid redundant
memory allocation. The allocator also supports copy-on-write operations.
Args:
num_blocks (int): The total number of blocks to manage.
block_size (int): The size of each block in tokens.
block_ids(Optional[Iterable[int]], optional): An optional iterable of
block IDs. If not provided, block IDs will be assigned sequentially
from 0 to num_blocks - 1.
"""
# TODO last access time / evictor integration
def __init__(
self,
num_blocks: int,
block_size: int,
block_ids: Optional[Iterable[int]] = None,
):
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0.
self._cached_blocks: Dict[PrefixHash, BlockId] = {}
# A mapping of prefix hash to block index. All blocks which have a
# prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
# of self._cached_blocks.
self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}
# An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator(
create_block=self._create_block,
num_blocks=num_blocks,
block_size=block_size,
block_ids=block_ids,
)
self._block_size = block_size
# We share the refcounter between allocators. This allows us to promote
# blocks originally allocated in the hashless allocator to immutable
# blocks.
self._refcounter = self._hashless_allocator.refcounter
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly(),
allocator=self,
)
# Implements Block.Factory.
def _create_block(
self,
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
) -> Block:
# Bind block to self.
allocator = self
return PrefixCachingBlock(
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
prefix_caching_allocator=allocator,
)
def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached
blocks if possible.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
token_ids (List[int]): The token IDs to be stored in the block.
Returns:
Block: The allocated immutable block.
"""
assert_prefix_caching_block_or_none(prev_block)
block = self._create_block(
prev_block=prev_block,
token_ids=token_ids,
block_size=self._block_size,
allocator=self,
)
assert block.content_hash is not None
cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
block.block_id = cached_block_id
self._incr_refcount_cached_block(block.content_hash,
block.block_id)
return block
block = self.allocate_mutable(prev_block)
block.append_token_ids(token_ids)
assert block.content_hash is not None
# TODO computed bit
return block
def allocate_mutable(self, prev_block: Block) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks.
Args:
prev_block (Block): The previous block in the sequence.
Returns:
Block: The allocated mutable block.
"""
assert_prefix_caching_block_or_none(prev_block)
try:
return self._hashless_allocator.allocate_mutable(
prev_block=prev_block)
except BlockAllocator.NoFreeBlocksError:
# We must check the unused cached blocks before raising OOM.
pass
if self._unused_cached_blocks:
# TODO policy for selecting block to remove
content_hash_to_evict = next(iter(self._unused_cached_blocks))
# Clear content hash mapping; the block will be overwritten.
del self._cached_blocks[content_hash_to_evict]
block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
refcount = self._refcounter.incr(block_id)
assert refcount == 1
block = self._create_block(
prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
allocator=self,
block_id=block_id,
)
assert block.content_hash is None
return block
# No block available in hashless allocator, nor in unused cache blocks.
raise BlockAllocator.NoFreeBlocksError()
def _incr_refcount_cached_block(self, content_hash: int,
block_id: BlockId) -> None:
refcount = self._refcounter.incr(block_id)
if refcount == 1:
assert content_hash in self._unused_cached_blocks
del self._unused_cached_blocks[content_hash]
def free(self, block: Block) -> None:
"""Decrement the refcount of the block. If the decremented refcount is
zero, store the block in the freelist.
If the block has a content hash (meaning it is immutable), then we will
keep the block around in case future allocations require it.
"""
assert (block.block_id
is not None), "freeing unallocated block is undefined"
self._free_block_id_for_block(block.block_id, block)
block.block_id = None
def _free_block_id_for_block(self, block_id: BlockId,
block: Block) -> None:
assert isinstance(block, PrefixCachingBlock)
if block.content_hash is None:
return self._hashless_allocator.free(block)
refcount = self._refcounter.decr(block_id)
# If no longer used, add the block to the unused cached blocks.
if refcount == 0:
assert block.content_hash not in self._unused_cached_blocks
assert block.content_hash in self._cached_blocks
self._unused_cached_blocks[block.content_hash] = block_id
def fork(self, last_block: Block) -> List[Block]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
Args:
last_block (Block): The last block in the original sequence.
Returns:
List[Block]: The new sequence of blocks that shares the same memory
as the original sequence.
"""
source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = []
prev_block = None
for block in source_blocks:
refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block"
forked_blocks.append(
self._create_block(
prev_block=prev_block,
token_ids=block.token_ids,
block_id=block.block_id,
block_size=self._block_size,
allocator=self,
))
prev_block = forked_blocks[-1]
return forked_blocks
def get_num_free_blocks(self) -> int:
# The number of free blocks is the number of hashless free blocks
# plus the number of hashful blocks that are unused.
return self._hashless_allocator.get_num_free_blocks() + len(
self._unused_cached_blocks)
@property
def all_block_ids(self) -> frozenset[int]:
return self._hashless_allocator.all_block_ids
def promote_to_immutable_block(self,
block: "PrefixCachingBlock") -> BlockId:
"""Once a mutable block is full, it can be promoted to an immutable
block. This means that its content can be referenced by future blocks
having the same prefix.
Note that if we already have a cached block with the same content, we
will replace the newly-promoted block's mapping with the existing cached
block.
Args:
block (PrefixCachingBlock): The mutable block to be promoted.
Returns:
BlockId: Either the original block index, or the block index of
the previously cached block matching the same content.
"""
assert block.content_hash is not None
assert block.block_id is not None
assert self._refcounter.get(block.block_id) > 0
# If the content hash does not have a corresponding cached block,
# set this block as the cached block.
if block.content_hash not in self._cached_blocks:
self._cached_blocks[block.content_hash] = block.block_id
else:
self._free_block_id_for_block(block.block_id, block)
self._incr_refcount_cached_block(
block.content_hash, self._cached_blocks[block.content_hash])
return self._cached_blocks[block.content_hash]
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
Args:
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
no copy-on-write was necessary.
"""
return self._cow_tracker.cow_block_if_not_appendable(block)
def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
"""Returns the copy-on-write source->destination mapping and clears it.
Returns:
Dict[BlockId, List[BlockId]]: A dictionary mapping source
block indices to lists of destination block indices.
"""
return self._cow_tracker.clear_cows()
def mark_blocks_as_computed(self) -> None:
"""Mark blocks as computed, used in prefix caching."""
# TODO Track computed blocks.
pass
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
"""
# TODO: Track computed blocks.
computed = lambda block_id: False
# NOTE We exclude the last block to avoid the case where the entire
# prompt is cached. This would cause erroneous behavior in model
# runner.
ids_list = [
list(takewhile(lambda block_id: computed(block_id), seq[:-1]))
for seq in seq_block_ids
]
return commonprefix([ids for ids in ids_list if ids != []])
class PrefixCachingBlock(Block):
"""A block implementation that supports prefix caching.
The PrefixCachingBlock class represents a block of token IDs with prefix
caching capabilities. It wraps a NaiveBlock internally and provides
additional functionality for content hashing and promoting immutable blocks
with the prefix caching allocator.
Args:
prev_block (Optional[PrefixCachingBlock]): The previous block in the
sequence.
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
"""
def __init__(
self,
prev_block: Optional["PrefixCachingBlock"],
token_ids: List[int],
block_size: int,
prefix_caching_allocator: PrefixCachingBlockAllocator,
block_id: Optional[int] = None,
):
assert_prefix_caching_block_or_none(prev_block)
self._prev_block = prev_block
self._cached_content_hash: Optional[int] = None
self._prefix_caching_allocator = prefix_caching_allocator
self._block = NaiveBlock(
prev_block=prev_block,
token_ids=token_ids,
block_size=block_size,
block_id=block_id,
allocator=prefix_caching_allocator,
_cow_target=self,
)
def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
Internally, the naive block handles CoW.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
assert token_ids
# naive block handles CoW.
self._block.append_token_ids(token_ids)
# If the content hash is present, then the block can be made immutable.
# Register ourselves with the allocator, potentially replacing the
# physical block index.
if self.content_hash is not None:
self.block_id = (self._prefix_caching_allocator.
promote_to_immutable_block(self))
@property
def block_id(self) -> Optional[int]:
return self._block.block_id
@block_id.setter
def block_id(self, value) -> None:
self._block.block_id = value
@property
def is_full(self) -> bool:
return self._block.is_full
@property
def num_empty_slots(self) -> int:
return self._block.num_empty_slots
@property
def block_size(self) -> int:
return self._block.block_size
@property
def token_ids(self) -> List[int]:
return self._block.token_ids
@property
def prev_block(self) -> Optional[Block]:
return self._prev_block
@property
def content_hash(self) -> Optional[int]:
"""Return the content-based hash of the current block, or None if it is
not yet defined.
For the content-based hash to be defined, the current block must be
full.
"""
# If the hash is already computed, return it.
if self._cached_content_hash is not None:
return self._cached_content_hash
# We cannot compute a hash for the current block because it is not full.
if not self.is_full:
return None
is_first_block = self._prev_block is None
prev_block_hash = (None if is_first_block else
self._prev_block.content_hash)
# Previous block exists but does not yet have a hash.
# Return no hash in this case.
if prev_block_hash is None and not is_first_block:
return None
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block,
prev_block_hash,
cur_block_token_ids=self.token_ids)
return self._cached_content_hash
@staticmethod
def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
cur_block_token_ids: List[int]) -> int:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching.
NOTE: Content-based hashing does not yet support LoRA.
Parameters:
- is_first_block (bool): A flag indicating if the block is the first in
the sequence.
- prev_block_hash (Optional[int]): The hash of the previous block. None
if this is the first block.
- cur_block_token_ids (List[int]): A list of token ids in the current
block. The current block is assumed to be full.
Returns:
- int: The computed hash value for the block.
"""
assert (prev_block_hash is None) == is_first_block
return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
def assert_prefix_caching_block_or_none(block: Optional[Block]):
if block is None:
return
assert isinstance(block, PrefixCachingBlock)
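A sketch of content-hash reuse with the PrefixCachingBlockAllocator (token ids and sizes are illustrative):
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator

allocator = PrefixCachingBlockAllocator(num_blocks=4, block_size=4)

first = allocator.allocate_immutable(prev_block=None, token_ids=[1, 2, 3, 4])
second = allocator.allocate_immutable(prev_block=None, token_ids=[1, 2, 3, 4])

# Same prefix (none) and the same full-block contents give the same content
# hash, so the second allocation reuses the first physical block.
assert first.content_hash == second.content_hash
assert first.block_id == second.block_id
assert allocator.get_num_free_blocks() == 3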

View File

@ -1,5 +1,4 @@
"""A block manager that manages token blocks."""
import enum
from abc import ABC, abstractmethod
from itertools import count, takewhile
from os.path import commonprefix
@ -7,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
@ -196,21 +196,7 @@ class UncachedBlockAllocator(BlockAllocatorBase):
"Invalid codepath for uncached block allocator.")
class AllocStatus(enum.Enum):
"""Result for BlockSpaceManager.can_allocate
1. Ok: seq_group can be allocated now.
2. Later: seq_group cannot be allocated.
The capacity of allocator is larger than seq_group required.
3. Never: seq_group can never be allocated.
The seq_group is too large to allocated in GPU.
"""
OK = enum.auto()
LATER = enum.auto()
NEVER = enum.auto()
class BlockSpaceManager:
class BlockSpaceManagerV1(BlockSpaceManager):
"""Manages the mapping between logical and physical token blocks."""
def __init__(
@ -355,6 +341,11 @@ class BlockSpaceManager:
self,
seq: Sequence,
) -> PhysicalTokenBlock:
# Called before a new block is appended.
# This is in charge of allocating a new physical block (to be appended).
# None if the last block is not full. Otherwise, we set it to the
# content hash.
if not self.enable_caching:
return self.gpu_allocator.allocate()
block_hash: Optional[int] = None
@ -362,7 +353,14 @@ class BlockSpaceManager:
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(
len(seq.logical_token_blocks) - 1)
# num_hashed_tokens is used to compute future hashes
# (e.g. in the hashing function, it is used to ask the sequence for
# prefix tokens)
new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
# If the block hash is None, then the block is not full.
# If the block is not full, then we expect it to have a refcount of 1.
if block_hash is None:
assert new_block.ref_count == 1
return new_block
@ -576,16 +574,16 @@ class BlockSpaceManager:
for b in takewhile(lambda b: b.computed, block_table[:-1])
]
def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
"""
# Can return non-empty result only with prefix caching enabled.
if not self.enable_caching:
return []
ids_list = [
self.get_all_computed_blocks(seq)
for seq in iter(seq_group.seqs_dict.values())
]
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
return commonprefix([ids for ids in ids_list if ids != []])
def mark_blocks_as_computed(self, seq_group: SequenceGroup):

View File

@ -0,0 +1,210 @@
"""A block manager that manages token blocks."""
from typing import Dict, List, Optional, Tuple
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
SeqId = int
class BlockSpaceManagerV2(BlockSpaceManager):
"""BlockSpaceManager which manages the allocation of KV cache.
It owns responsibility for allocation, swapping, allocating memory for
autoregressively-generated tokens, and other advanced features such as
prefix caching, forking/copy-on-write, and sliding-window memory allocation.
The current implementation is partial; in particular prefix caching and
sliding-window are not feature complete. This class implements the design
described in https://github.com/vllm-project/vllm/pull/3492.
Args:
block_size (int): The size of each memory block.
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
num_cpu_blocks (int): The number of memory blocks allocated on CPU.
watermark (float, optional): The threshold used for memory swapping.
Defaults to 0.01.
sliding_window (Optional[int], optional): The size of the sliding
window. Defaults to None.
enable_caching (bool, optional): Flag indicating whether caching is
enabled. Defaults to False.
"""
def __init__(
self,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
watermark: float = 0.01,
sliding_window: Optional[int] = None,
enable_caching: bool = False,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks
assert sliding_window is None, "Sliding window not yet supported"
self.block_sliding_window = None
self.watermark = watermark
assert watermark >= 0.0
assert not enable_caching, "Prefix caching not yet supported"
self.enable_caching = enable_caching
self.watermark_blocks = int(watermark * num_gpu_blocks)
self.block_allocator = CpuGpuBlockAllocator.create(
# Currently, only naive blocks are supported (no prefix caching).
allocator_type="naive",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=block_size,
)
self.block_tables: Dict[SeqId, BlockTable] = {}
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = BlockTable.get_num_required_blocks(
seq.get_token_ids(),
block_size=self.block_size,
)
assert self.block_sliding_window is None
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
device=Device.GPU)
# Use watermark to avoid frequent cache eviction.
if (self.num_total_gpu_blocks - num_required_blocks <
self.watermark_blocks):
return AllocStatus.NEVER
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
else:
return AllocStatus.LATER
def allocate(self, seq_group: SequenceGroup) -> None:
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
assert not (set(seq.seq_id for seq in waiting_seqs)
& self.block_tables.keys()), "block table already exists"
# NOTE: Here we assume that all sequences in the group have the same
# prompt.
seq = waiting_seqs[0]
block_table = BlockTable(
block_size=self.block_size,
block_allocator=self.block_allocator,
)
assert self.block_sliding_window is None
block_table.allocate(seq.get_token_ids())
self.block_tables[seq.seq_id] = block_table
# Assign the block table for each sequence.
for seq in waiting_seqs[1:]:
self.block_tables[seq.seq_id] = block_table.fork()
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
Device.GPU)
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks
def append_slot(
self,
seq: Sequence,
) -> Optional[Tuple[int, int]]:
block_table = self.block_tables[seq.seq_id]
# Get unseen token ids.
num_full_slots = block_table.num_full_slots
unseen_token_ids = seq.get_token_ids()[num_full_slots:]
assert unseen_token_ids
block_table.append_token_ids(unseen_token_ids)
# Return any copy-on-writes.
_ = self.block_allocator.clear_copy_on_writes()
# TODO extend append_slot interface to append_slots
# @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250
return None
def free(self, seq: Sequence) -> None:
if seq.seq_id not in self.block_tables:
# Already freed or hasn't been scheduled yet.
return
self.block_tables[seq.seq_id].free()
del self.block_tables[seq.seq_id]
def get_block_table(self, seq: Sequence) -> List[int]:
assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids)
return block_ids
def access_all_blocks_in_seq(self, seq, now):
# TODO add prefix caching support.
# Tracked here https://github.com/vllm-project/vllm/issues/3667
pass
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# We ignore the sequence group as it's not necessary. After the batch is
# formed by the scheduler, we do not need to mark blocks from individual
# sequence groups as computed -- all blocks in the batch can be marked
# as computed.
self.block_allocator.mark_blocks_as_computed()
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
"""Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks.
Currently, the attention implementation only supports skipping cached
blocks if they are a contiguous prefix of cached blocks.
This method determines which blocks can be safely skipped for all
sequences in the sequence group.
"""
seq_block_ids = [
self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
]
return self.block_allocator.get_common_computed_block_ids(
seq_block_ids)
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork()
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
return False
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
raise NotImplementedError
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
return False
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
raise NotImplementedError
def get_num_free_gpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.GPU)
def get_num_free_cpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.CPU)
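A bookkeeping-only sketch of the v2 manager (no sequences involved; the block counts and watermark are illustrative):
from vllm.core.block_manager_v2 import BlockSpaceManagerV2

manager = BlockSpaceManagerV2(
    block_size=16,
    num_gpu_blocks=100,
    num_cpu_blocks=100,
    watermark=0.01,                       # reserve 1% of GPU blocks
)
assert manager.watermark_blocks == 1      # int(0.01 * 100)
assert manager.get_num_free_gpu_blocks() == 100
assert manager.get_num_free_cpu_blocks() == 100
# can_allocate() answers OK only while the free GPU blocks minus the
# request's required blocks stay at or above watermark_blocks; otherwise it
# answers LATER, or NEVER if the prompt can never fit.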

vllm/core/interfaces.py Normal file
View File

@ -0,0 +1,107 @@
import enum
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from vllm.sequence import Sequence, SequenceGroup
class AllocStatus(enum.Enum):
"""Result for BlockSpaceManager.can_allocate
1. Ok: seq_group can be allocated now.
2. Later: seq_group cannot be allocated now.
The allocator's total capacity is larger than what the seq_group
requires, so it can be allocated once blocks are freed.
3. Never: seq_group can never be allocated.
The seq_group is too large to be allocated in GPU memory.
"""
OK = enum.auto()
LATER = enum.auto()
NEVER = enum.auto()
class BlockSpaceManager(ABC):
@staticmethod
def get_block_space_manager_class(version: str):
version = version.lower()
if version == "v1":
from vllm.core.block_manager_v1 import BlockSpaceManagerV1
return BlockSpaceManagerV1
if version == "v2":
from vllm.core.block_manager_v2 import BlockSpaceManagerV2
return BlockSpaceManagerV2
raise ValueError(f"Unknown version {version=}")
@abstractmethod
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
pass
@abstractmethod
def allocate(self, seq_group: SequenceGroup) -> None:
pass
@abstractmethod
def can_append_slot(self, seq_group: SequenceGroup) -> bool:
pass
@abstractmethod
def append_slot(
self,
seq: Sequence,
) -> Optional[Tuple[int, int]]:
pass
@abstractmethod
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
pass
@abstractmethod
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
pass
@abstractmethod
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
pass
@abstractmethod
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
pass
@abstractmethod
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
pass
@abstractmethod
def free(self, seq: Sequence) -> None:
pass
@abstractmethod
def get_block_table(self, seq: Sequence) -> List[int]:
pass
@abstractmethod
def get_num_free_gpu_blocks(self) -> int:
pass
@abstractmethod
def get_num_free_cpu_blocks(self) -> int:
pass
@abstractmethod
def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
pass
@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass
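A sketch of how a caller selects the implementation through the new interface, mirroring what the scheduler does with use_v2_block_manager (block counts are illustrative):
from vllm.core.interfaces import BlockSpaceManager

BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
    version="v2")                         # or "v1" for the existing manager
block_manager = BlockSpaceManagerImpl(
    block_size=16,
    num_gpu_blocks=256,
    num_cpu_blocks=256,
)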

View File

@ -4,7 +4,7 @@ from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -88,8 +88,13 @@ class Scheduler:
# Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
version="v2" if self.scheduler_config.
use_v2_block_manager else "v1")
# Create the block space manager.
self.block_manager = BlockSpaceManager(
self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
@ -378,6 +383,10 @@ class Scheduler:
block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now)
common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=scheduler_outputs.prompt_run,
@ -385,8 +394,7 @@ class Scheduler:
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
lora_request=seq_group.lora_request,
computed_block_nums=self.block_manager.
get_common_computed_block_ids(seq_group),
computed_block_nums=common_computed_block_nums,
state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
@ -396,6 +404,14 @@ class Scheduler:
if scheduler_outputs.prompt_run else None,
)
seq_group_metadata_list.append(seq_group_metadata)
# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
# will crash the vLLM instance / will not retry.
for seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(seq_group)
return seq_group_metadata_list, scheduler_outputs
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
@ -503,9 +519,6 @@ class Scheduler:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
self.block_manager.mark_blocks_as_computed(seq_group)
def _passed_delay(self, now: float) -> bool:
if self.prev_prompt:
self.last_prompt_latency = now - self.prev_time

View File

@ -28,6 +28,7 @@ class EngineArgs:
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
enable_prefix_caching: bool = False
use_v2_block_manager: bool = False
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
@ -52,6 +53,9 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False
forced_num_gpu_blocks: Optional[int] = None
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
image_token_id: Optional[int] = None
@ -194,6 +198,9 @@ class EngineArgs:
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceManagerV2')
parser.add_argument('--seed',
type=int,
@ -210,6 +217,12 @@ class EngineArgs:
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument(
'--forced-num-gpu-blocks',
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number '
'of GPU blocks. Used for testing preemption.')
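# Illustrative use of the two new flags from Python (not part of this file;
# the model choice and block count are arbitrary):
#
#     from vllm import LLM
#     llm = LLM(model="facebook/opt-125m",
#               use_v2_block_manager=True,     # schedule via BlockSpaceManagerV2
#               forced_num_gpu_blocks=64)      # override the profiled block count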
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
@ -369,6 +382,7 @@ class EngineArgs:
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
self.forced_num_gpu_blocks,
model_config.get_sliding_window(),
self.enable_prefix_caching)
parallel_config = ParallelConfig(
@ -383,6 +397,7 @@ class EngineArgs:
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
self.use_v2_block_manager,
self.scheduler_delay_factor)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,

View File

@ -553,12 +553,6 @@ class LLMEngine:
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
# If prefix caching is enabled, mark all blocks in the sequence groups
# as completed so that future requests don't attempt to recompute them
if self.cache_config.enable_prefix_caching:
for seq_group in scheduled_seq_groups:
self.scheduler.mark_blocks_as_computed(seq_group)
for seq_group, outputs in zip(scheduled_seq_groups, output):
self._process_sequence_group_outputs(seq_group, outputs)

View File

@ -85,6 +85,12 @@ class GPUExecutor(ExecutorBase):
cache_dtype=self.cache_config.cache_dtype,
))
if self.cache_config.forced_num_gpu_blocks is not None:
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
f"{forced_num_gpu_blocks=}")
num_gpu_blocks = forced_num_gpu_blocks
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")

View File

@ -232,6 +232,13 @@ class RayGPUExecutor(ExecutorBase):
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
if self.cache_config.forced_num_gpu_blocks is not None:
forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
logger.info(f"Replacing profiled {num_gpu_blocks=} with "
f"{forced_num_gpu_blocks=}")
num_gpu_blocks = forced_num_gpu_blocks
logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")

View File

@ -196,6 +196,8 @@ class Sequence:
return self.lora_request.lora_int_id if self.lora_request else 0
def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
# Compute the number of tokens in the sequence
# TODO: The current hashing function is O(L^2). We should optimize
# this in the future.

View File

@ -227,6 +227,16 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
def chunk_list(lst, chunk_size):
"""Yield successive chunk_size chunks from lst."""
return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
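# Illustrative behaviour of the two helpers above (not part of the diff):
#   chunk_list([1, 2, 3, 4, 5], chunk_size=2)  ->  [[1, 2], [3, 4], [5]]
#   cdiv(5, 2)                                 ->  3  (ceiling of 5 / 2)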
@lru_cache(maxsize=None)
def get_nvcc_cuda_version() -> Optional[Version]:
cuda_home = os.environ.get('CUDA_HOME')