diff --git a/tests/core/block/__init__.py b/tests/core/block/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py new file mode 100644 index 00000000..e1a9dd28 --- /dev/null +++ b/tests/core/block/e2e/conftest.py @@ -0,0 +1,56 @@ +import contextlib +import gc + +import pytest +import ray +import torch + +from vllm import LLM +from vllm.model_executor.parallel_utils.parallel_state import ( + destroy_model_parallel) +from vllm.model_executor.utils import set_random_seed + + +def cleanup(): + destroy_model_parallel() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + torch.cuda.empty_cache() + ray.shutdown() + + +@pytest.fixture +def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, seed): + return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, seed) + + +@pytest.fixture +def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + test_llm_kwargs, seed): + return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + test_llm_kwargs, seed) + + +def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + distinct_llm_kwargs, seed): + kwargs = { + **common_llm_kwargs, + **per_test_common_llm_kwargs, + **distinct_llm_kwargs, + } + + def generator_inner(): + llm = LLM(**kwargs) + + set_random_seed(seed) + + yield llm + del llm + cleanup() + + for llm in generator_inner(): + yield llm + del llm diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py new file mode 100644 index 00000000..283d99fe --- /dev/null +++ b/tests/core/block/e2e/test_correctness.py @@ -0,0 +1,86 @@ +from itertools import cycle + +import pytest + +from vllm import SamplingParams + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Allow only 5 sequences of ~1024 tokens in worst case. + "block_size": 16, + "forced_num_gpu_blocks": 5 * (64 + 1), + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify block manager v2 produces same outputs as block manager v1, even + when there is preemption. + + This constructs two LLM, each with limited number of GPU blocks. The limit + is decided such that as the sequences in the batch grow, sequences must be + preempted and removed from cache. + + If the output token ids are equivalent, then we have confidence that the KV + cache is not corrupted in the v2 block manager. + + NOTE: We want a significant number of generated tokens so that any incorrect + KV mapping has time to build up error. + """ + output_len = 1024 + temperature = 0.0 + + # We want to ensure equality even with preemption. + # We force the total block size to be 1 + cdiv(output_len, block_size) + # so that only one sequence can fit at a time (once the sequences grow). 
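+    # Concretely: output_len=1024 and block_size=16, so each sequence needs
+    # 1 + cdiv(1024, 16) = 65 blocks once fully grown. With
+    # forced_num_gpu_blocks = 5 * 65 and a batch of 10 sequences, preemption
+    # is guaranteed as the sequences grow.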
+ + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids from block manager v1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids from block manager v2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + +def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): + for llm in llm_generator: + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] + del llm + + return token_ids diff --git a/tests/core/block/test_block_space_manager.py b/tests/core/block/test_block_space_manager.py new file mode 100644 index 00000000..eec8cbcb --- /dev/null +++ b/tests/core/block/test_block_space_manager.py @@ -0,0 +1,50 @@ +import pytest + +from vllm.core.block_manager_v2 import BlockSpaceManagerV2 +from vllm.core.interfaces import AllocStatus + +from ..utils import create_seq_group + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80]) +@pytest.mark.parametrize("num_seqs_per_group", [1, 4]) +@pytest.mark.parametrize("watermark", [0.0, 0.5]) +def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, + num_gpu_blocks: int, watermark: float): + block_manager = BlockSpaceManagerV2( + block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + watermark=watermark, + ) + num_watermark_blocks = int(watermark * num_gpu_blocks) + + num_output_blocks_per_seq = 1 + + # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but + # the current implementation assumes all seqs are new prompts / don't have + # different output lens. 
+ num_output_blocks = num_output_blocks_per_seq + + for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks): + seq_group = create_seq_group( + seq_prompt_lens=block_size * num_prompt_blocks, + seq_output_lens=[ + block_size * num_output_blocks_per_seq + for _ in range(num_seqs_per_group) + ], + ) + + assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks + + can_allocate_result = block_manager.can_allocate(seq_group) + + num_required_blocks = num_prompt_blocks + num_output_blocks + + if num_gpu_blocks - num_required_blocks < num_watermark_blocks: + assert can_allocate_result == AllocStatus.NEVER + elif num_gpu_blocks >= num_required_blocks: + assert can_allocate_result == AllocStatus.OK + else: + assert can_allocate_result == AllocStatus.LATER diff --git a/tests/core/block/test_block_table.py b/tests/core/block/test_block_table.py new file mode 100644 index 00000000..a7c5aa2b --- /dev/null +++ b/tests/core/block/test_block_table.py @@ -0,0 +1,500 @@ +import pytest + +from vllm.core.block.block_table import BlockTable +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.utils import Device, cdiv, chunk_list + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("sequence_len", [1, 16, 129]) +def test_allocate_naive(block_size: int, sequence_len: int): + """Test the allocation of blocks using the naive allocator. + + This test creates a CpuGpuBlockAllocator with the specified block size and + number of blocks. It then allocates multiple BlockTables with varying + sequence lengths and verifies that the number of free blocks decreases as + expected after each allocation. + """ + assert block_size > 1 + num_gpu_blocks = 1024 + + allocator = CpuGpuBlockAllocator.create( + allocator_type="naive", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + block_size=block_size, + ) + + token_ids = list(range(sequence_len)) + num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size))) + + block_tables = [] + for i in range(5): + assert allocator.get_num_free_blocks( + device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc + + block_tables.append( + BlockTable( + block_size=block_size, + block_allocator=allocator, + )) + block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU) + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("sequence_len", [1, 16, 129]) +def test_allocate_prefix_caching(block_size: int, sequence_len: int): + """Test the allocation of blocks using the prefix caching allocator. + + This test creates a CpuGpuBlockAllocator with the specified block size and + number of blocks, using the prefix caching allocator. It then allocates + multiple BlockTables with varying sequence lengths and verifies that the + number of free blocks decreases as expected after each allocation. + + The test expects all sequences to share allocations, except for their last + block, which may be mutable. It calculates the expected number of immutable + and mutable blocks per allocation based on the sequence length and block + size. 
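+    For example, with block_size=16 and sequence_len=129, each allocation
+    consists of 8 full (immutable) blocks and 1 partial (mutable) block, so
+    after N allocations only 8 + N blocks are consumed, since the full
+    blocks are shared across all sequences.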
+ """ + assert block_size > 1 + num_gpu_blocks = 1024 + + allocator = CpuGpuBlockAllocator.create( + allocator_type="prefix_caching", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + block_size=block_size, + ) + + token_ids = list(range(sequence_len)) + chunked_tokens = list(chunk_list(token_ids, block_size)) + num_mutable_blocks_per_alloc = 0 if len( + chunked_tokens[-1]) == block_size else 1 + num_immutable_blocks_per_alloc = len( + chunked_tokens) - num_mutable_blocks_per_alloc + + block_tables = [] + for alloc_i in range(1, 6): + + block_tables.append( + BlockTable( + block_size=block_size, + block_allocator=allocator, + )) + block_tables[-1].allocate(token_ids=token_ids, device=Device.GPU) + + # Expect all sequences to share allocations, except for their last block + # (which may be mutable). + assert allocator.get_num_free_blocks( + device=Device.GPU) == num_gpu_blocks - ( + num_immutable_blocks_per_alloc + num_mutable_blocks_per_alloc * + (alloc_i)) + + +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("sequence_len", [1, 16, 129]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +@pytest.mark.parametrize("device", ["cpu", "gpu"]) +def test_allocate_free(block_size: int, sequence_len: int, allocator_type: str, + device: str): + """Test the allocation and freeing of blocks using different allocators and + devices. + + This test creates a CpuGpuBlockAllocator with the specified block size, + number of blocks, allocator type, and device. It then allocates a BlockTable + multiple times with the same sequence and verifies that the number of free + blocks remains consistent after each allocation and freeing. + """ + device = Device[device.upper()] + + num_device_blocks = 1024 + allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_device_blocks, + num_cpu_blocks=num_device_blocks, + block_size=block_size, + ) + + token_ids = list(range(sequence_len)) + num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size))) + + block_table = BlockTable( + block_size=block_size, + block_allocator=allocator, + ) + + for i in range(5): + block_table.allocate(token_ids=token_ids, device=device) + assert allocator.get_num_free_blocks( + device) == num_device_blocks - num_blocks_per_alloc + assert all(block_id is not None + for block_id in block_table.physical_block_ids) + + block_table.free() + assert allocator.get_num_free_blocks(device) == num_device_blocks + + +@pytest.mark.parametrize("block_size", [1, 8]) +@pytest.mark.parametrize("sequence_len", [1, 16, 129]) +@pytest.mark.parametrize("append_len", [1, 16, 129]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +def test_append_token_ids_allocation(block_size: int, sequence_len: int, + append_len: int, allocator_type: str): + """Test the allocation behavior when appending token IDs to a BlockTable. + + This test creates a CpuGpuBlockAllocator with the specified block size, + number of blocks, and allocator type. It then allocates a BlockTable with an + initial sequence and appends additional token IDs to it. The test verifies + that the number of allocated blocks before and after appending matches the + expected values. 
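+    For example, with block_size=8, sequence_len=16, and append_len=129, the
+    initial allocation uses 2 blocks and the append consumes
+    cdiv(16 + 129, 8) - 2 = 17 additional blocks.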
+ """ + + num_gpu_blocks = 1024 + + allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + block_size=block_size, + ) + + token_ids = list(range(sequence_len)) + token_ids_to_append = list(range(append_len)) + + block_table = BlockTable( + block_size=block_size, + block_allocator=allocator, + ) + + num_expected_blocks_before_append = len( + list(chunk_list(token_ids, block_size))) + num_expected_appended_blocks = len( + list(chunk_list(token_ids + token_ids_to_append, + block_size))) - num_expected_blocks_before_append + + block_table.allocate(token_ids=token_ids, device=Device.GPU) + + assert len( + block_table.physical_block_ids) == num_expected_blocks_before_append + block_table.append_token_ids(token_ids_to_append) + assert len( + block_table.physical_block_ids + ) == num_expected_blocks_before_append + num_expected_appended_blocks + + +@pytest.mark.parametrize("block_size", [1, 8]) +@pytest.mark.parametrize("sequence_len", [1, 16, 129]) +@pytest.mark.parametrize("num_empty_slots", [1, 16, 129]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int, + num_empty_slots: int, + allocator_type: str): + """Test the allocation behavior when ensuring a certain number of empty + slots in a BlockTable. + + This test creates a CpuGpuBlockAllocator with the specified block size, + number of blocks, and allocator type. It then allocates a BlockTable with an + initial sequence and ensures a certain number of empty slots. The test + verifies that the number of allocated blocks before and after ensuring empty + slots matches the expected values. It also checks that filling up the empty + slots does not consume additional blocks. + """ + num_gpu_blocks = 1024 + + allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + block_size=block_size, + ) + + token_ids = list(range(sequence_len)) + + block_table = BlockTable( + block_size=block_size, + block_allocator=allocator, + ) + + num_expected_blocks_before_append = len( + list(chunk_list(token_ids, block_size))) + num_expected_appended_blocks = len( + list(chunk_list(token_ids + [-1] * num_empty_slots, + block_size))) - num_expected_blocks_before_append + + block_table.allocate(token_ids=token_ids, device=Device.GPU) + + # Assert that the empty slots consume the expected number of additional + # blocks. + assert len( + block_table.physical_block_ids) == num_expected_blocks_before_append + block_table.ensure_num_empty_slots(num_empty_slots) + assert len( + block_table.physical_block_ids + ) == num_expected_blocks_before_append + num_expected_appended_blocks + + # Now, ensure no additional blocks consumed as we fill up the empty slots. + num_free_blocks = allocator.get_num_free_blocks(device=Device.GPU) + block_table.append_token_ids(token_ids=list(range(num_empty_slots))) + assert num_free_blocks == allocator.get_num_free_blocks(device=Device.GPU) + + +@pytest.mark.parametrize("block_size", [1, 8]) +@pytest.mark.parametrize("sequence_len", [1, 9]) +@pytest.mark.parametrize("append_len", [1, 16, 129]) +@pytest.mark.parametrize("append_size", [1, 4, 129]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +def test_append_token_ids_correct_content(block_size: int, sequence_len: int, + append_len: int, allocator_type: str, + append_size: int): + """Verify token ids are correctly appended. 
Appends various amounts of + token ids in various append sizes, and verifies the final sequence is + correct. + """ + num_gpu_blocks = 1024 + + allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=1024, + block_size=block_size, + ) + + token_ids = list(range(sequence_len)) + token_ids_to_append = list(range(append_len)) + + block_table = BlockTable( + block_size=block_size, + block_allocator=allocator, + ) + block_table.allocate(token_ids=token_ids, device=Device.GPU) + + appended_so_far = [] + for append in chunk_list(token_ids_to_append, append_size): + block_table.append_token_ids(append) + appended_so_far.extend(append) + + assert block_table._get_all_token_ids() == token_ids + appended_so_far + + assert block_table._get_all_token_ids() == token_ids + token_ids_to_append + + +@pytest.mark.parametrize("seq_len", [1, 9, 129]) +@pytest.mark.parametrize("block_size", [1, 8]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +def test_fork(seq_len: int, block_size: int, allocator_type: str): + """Create a sequence using the specified allocator. + 1. Assert that after forking the sequence, the free block count is the + same. + 2. Assert that the forked sequence has the same physical mappings. + 3. Then free the original sequence; verify that the free block count is + the same. + 4. Finally, free the forked sequence and verify that the free block + count drops to zero. + """ + num_gpu_blocks = 1024 + + allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + block_size=block_size, + ) + + token_ids = list(range(seq_len)) + + block_table = BlockTable( + block_size=block_size, + block_allocator=allocator, + ) + + block_table.allocate(token_ids) + + num_free_blocks_before_fork = allocator.get_num_free_blocks( + device=Device.GPU) + + forked_block_table = block_table.fork() + + # Expect physical_block_ids and token_ids to match. + assert (block_table.physical_block_ids == + forked_block_table.physical_block_ids) + assert block_table._get_all_token_ids( + ) == forked_block_table._get_all_token_ids() + + # Do not expect any additional allocations. + assert allocator.get_num_free_blocks( + device=Device.GPU) == num_free_blocks_before_fork + + # Free the original blocks. Assert num free blocks does not change, since + # refcount is nonzero. + block_table.free() + assert allocator.get_num_free_blocks( + device=Device.GPU) == num_free_blocks_before_fork + + # Expect the forked block table to be unaffected by the free. + assert all(block_id is not None + for block_id in forked_block_table.physical_block_ids) + + # Free the forked blocks. Assert num free blocks does change, since + # refcount is now zero. + forked_block_table.free() + assert allocator.get_num_free_blocks(device=Device.GPU) == num_gpu_blocks + + +@pytest.mark.parametrize("block_size", [8]) +@pytest.mark.parametrize("sequence_len", [1, 16, 129]) +@pytest.mark.parametrize("append_len", [1, 16, 129]) +@pytest.mark.parametrize("appender", ["forked", "original"]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +def test_cow(block_size: int, sequence_len: int, append_len: int, + allocator_type: str, appender: str): + """Fork a sequence; append to the forked sequence; verify there's a CoW. 
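+    Copy-on-write (CoW) means the forked sequence initially shares the
+    original's blocks; the first append to a shared partial block allocates
+    a new block and records a src -> dst copy, which the allocator exposes
+    via clear_copy_on_writes().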
+ """ + num_gpu_blocks = 1024 + + allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + block_size=block_size, + ) + + token_ids = list(range(sequence_len)) + token_ids_to_append = list(range(append_len)) + + original_block_table = BlockTable( + block_size=block_size, + block_allocator=allocator, + ) + + num_expected_non_cow_blocks = cdiv(sequence_len, block_size) + num_expected_cow_blocks = cdiv(sequence_len + append_len, + block_size) - (sequence_len // block_size) + + original_block_table.allocate(token_ids=token_ids, device=Device.GPU) + original_block_ids = original_block_table.physical_block_ids + + forked_block_table = original_block_table.fork() + + # Expect no additional allocation (copy on _write_). + assert allocator.get_num_free_blocks( + Device.GPU) == (num_gpu_blocks - num_expected_non_cow_blocks) + + if appender == "forked": + appender_block_table = forked_block_table + static_block_table = original_block_table + elif appender == "original": + appender_block_table = original_block_table + static_block_table = forked_block_table + else: + raise ValueError(f"unknown test config {appender=}") + + # Write tokens. + appender_block_table.append_token_ids(token_ids_to_append) + + # Expect the non-appending block table to have no change. + assert static_block_table.physical_block_ids == original_block_ids + assert appender_block_table.physical_block_ids != original_block_ids + + # Expect the blocks changed during append to have a CoW. + assert allocator.get_num_free_blocks( + Device.GPU) == num_gpu_blocks - (num_expected_non_cow_blocks + + num_expected_cow_blocks) + + cows = allocator.clear_copy_on_writes() + if sequence_len % block_size > 0: + # If the last block in the sequence is not full, then when appending we + # expect a CoW. + assert cows + + cow_block_id = sequence_len // block_size + expected_src = static_block_table.physical_block_ids[cow_block_id] + expected_dst = appender_block_table.physical_block_ids[cow_block_id] + + assert expected_src in cows + assert expected_dst in cows[expected_src] + else: + # Otherwise, there should be no copy-on-write. + assert not cows + + static_block_table.free() + appender_block_table.free() + + # After free, expect all blocks to be freed. + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + +@pytest.mark.parametrize("block_size", [8]) +@pytest.mark.parametrize("sequence_len", [1, 16, 129]) +@pytest.mark.parametrize("append_len", [1, 16, 129]) +@pytest.mark.parametrize("lookahead_slots", [1, 16, 129]) +@pytest.mark.parametrize("appender", ["forked", "original"]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +def test_cow_lookahead_simple(block_size: int, sequence_len: int, + append_len: int, lookahead_slots: int, + allocator_type: str, appender: str): + """Similar to test_cow, except with lookahead allocation. The assertions are + less rigorous due to the complexity of the property under test. + """ + num_gpu_blocks = 1024 + + allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=0, + block_size=block_size, + ) + + token_ids = list(range(sequence_len)) + token_ids_to_append = list(range(append_len)) + + original_block_table = BlockTable( + block_size=block_size, + block_allocator=allocator, + ) + + original_block_table.allocate(token_ids=token_ids, device=Device.GPU) + + # Allocate lookahead slots. 
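+    # (Lookahead slots reserve empty slots beyond the tokens written so far,
+    # e.g. for speculative decoding, so later appends can land in blocks
+    # that were allocated ahead of time.)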
+ original_block_table.ensure_num_empty_slots(lookahead_slots) + original_block_ids = original_block_table.physical_block_ids + + forked_block_table = original_block_table.fork() + + if appender == "forked": + appender_block_table = forked_block_table + static_block_table = original_block_table + elif appender == "original": + appender_block_table = original_block_table + static_block_table = forked_block_table + else: + raise ValueError(f"unknown test config {appender=}") + + # Write tokens. + appender_block_table.append_token_ids(token_ids_to_append) + + # Expect the non-appending block table to have no change. + assert static_block_table.physical_block_ids == original_block_ids + assert appender_block_table.physical_block_ids != original_block_ids + + cows = allocator.clear_copy_on_writes() + + # Always expect copy-on-write + assert cows + + if sequence_len % block_size > 0: + # If the last block in the sequence is not full, then when appending we + # expect a CoW. + assert cows + + cow_block_id = sequence_len // block_size + expected_src = static_block_table.physical_block_ids[cow_block_id] + expected_dst = appender_block_table.physical_block_ids[cow_block_id] + + assert expected_src in cows + assert expected_dst in cows[expected_src] + + static_block_table.free() + appender_block_table.free() + + # After free, expect all blocks to be freed. + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks diff --git a/tests/core/block/test_common.py b/tests/core/block/test_common.py new file mode 100644 index 00000000..cfdd3582 --- /dev/null +++ b/tests/core/block/test_common.py @@ -0,0 +1,42 @@ +import random + +import pytest + +from vllm.core.block.common import RefCounter + + +@pytest.mark.parametrize("seed", list(range(20))) +@pytest.mark.parametrize("num_incrs", [1, 100]) +@pytest.mark.parametrize("num_blocks", [1024]) +def test_incr(seed: int, num_incrs: int, num_blocks: int): + random.seed(seed) + + all_block_indices = list(range(num_blocks)) + counter = RefCounter(all_block_indices=all_block_indices) + + block_id = random.randint(0, num_blocks - 1) + for i in range(num_incrs): + value = counter.incr(block_id) + assert value == i + 1 + + +@pytest.mark.parametrize("seed", list(range(20))) +@pytest.mark.parametrize("num_incrs", [1, 100]) +@pytest.mark.parametrize("num_blocks", [1024]) +def test_incr_decr(seed: int, num_incrs: int, num_blocks: int): + random.seed(seed) + + all_block_indices = list(range(num_blocks)) + counter = RefCounter(all_block_indices=all_block_indices) + + block_id = random.randint(0, num_blocks - 1) + for i in range(num_incrs): + value = counter.incr(block_id) + assert value == i + 1 + + for i in range(num_incrs): + value = counter.decr(block_id) + assert value == num_incrs - (i + 1) + + with pytest.raises(AssertionError): + counter.decr(block_id) diff --git a/tests/core/block/test_cpu_gpu_block_allocator.py b/tests/core/block/test_cpu_gpu_block_allocator.py new file mode 100644 index 00000000..44a5be6c --- /dev/null +++ b/tests/core/block/test_cpu_gpu_block_allocator.py @@ -0,0 +1,93 @@ +import pytest + +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.utils import Device, chunk_list + + +@pytest.mark.parametrize("num_cpu_blocks", [0, 512]) +@pytest.mark.parametrize("num_gpu_blocks", [1024]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int, + block_size: int, allocator_type: str): + 
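+    """Verify that mutable blocks can be allocated and freed on both CPU and
+    GPU, and that the per-device free-block counts update accordingly."""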
allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + cpu_blocks = [ + allocator.allocate_mutable(prev_block=None, device=Device.CPU) + for _ in range(num_cpu_blocks) + ] + assert allocator.get_num_free_blocks(Device.CPU) == 0 + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + gpu_blocks = [ + allocator.allocate_mutable(prev_block=None, device=Device.GPU) + for _ in range(num_gpu_blocks) + ] + assert allocator.get_num_free_blocks(Device.CPU) == 0 + assert allocator.get_num_free_blocks(Device.GPU) == 0 + + _ = [allocator.free(block) for block in cpu_blocks] + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == 0 + + _ = [allocator.free(block) for block in gpu_blocks] + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + +@pytest.mark.parametrize("num_cpu_blocks", [0, 512]) +@pytest.mark.parametrize("num_gpu_blocks", [1024]) +@pytest.mark.parametrize("block_size", [2]) +@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"]) +def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int, + block_size: int, allocator_type: str): + allocator = CpuGpuBlockAllocator.create( + allocator_type=allocator_type, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + unique_token_ids = list( + range((num_cpu_blocks + num_gpu_blocks) * block_size)) + gpu_token_ids = chunk_list(unique_token_ids[:num_gpu_blocks * block_size], + block_size) + cpu_token_ids = chunk_list(unique_token_ids[num_gpu_blocks * block_size:], + block_size) + + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + cpu_blocks = [ + allocator.allocate_immutable(prev_block=None, + token_ids=token_ids, + device=Device.CPU) + for token_ids in cpu_token_ids + ] + assert allocator.get_num_free_blocks(Device.CPU) == 0 + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks + + gpu_blocks = [ + allocator.allocate_immutable(prev_block=None, + token_ids=token_ids, + device=Device.GPU) + for token_ids in gpu_token_ids + ] + assert allocator.get_num_free_blocks(Device.CPU) == 0 + assert allocator.get_num_free_blocks(Device.GPU) == 0 + + _ = [allocator.free(block) for block in cpu_blocks] + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == 0 + + _ = [allocator.free(block) for block in gpu_blocks] + assert allocator.get_num_free_blocks(Device.CPU) == num_cpu_blocks + assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks diff --git a/tests/core/block/test_naive_block.py b/tests/core/block/test_naive_block.py new file mode 100644 index 00000000..edcdc0c7 --- /dev/null +++ b/tests/core/block/test_naive_block.py @@ -0,0 +1,102 @@ +from typing import List, Optional + +import pytest + +from vllm.core.block.interfaces import Block, BlockAllocator +from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator + + +class TestNaiveBlockAllocator: + + @staticmethod + def create_allocate_lambda(allocate_type: str, + allocator: NaiveBlockAllocator, + prev_block: 
Optional[Block], + token_ids: List[int]): + if allocate_type == "immutable": + allocate_block = lambda: allocator.allocate_immutable( + prev_block=prev_block, token_ids=token_ids) + elif allocate_type == "mutable": + allocate_block = lambda: allocator.allocate_mutable(prev_block= + prev_block) + else: + raise ValueError() + + return allocate_block + + @staticmethod + @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) + @pytest.mark.parametrize("num_blocks", [1, 1024]) + @pytest.mark.parametrize("block_size", [1, 16]) + def test_allocate_ooms(allocate_type: str, num_blocks: int, + block_size: int): + allocator = NaiveBlockAllocator(create_block=NaiveBlock, + num_blocks=num_blocks, + block_size=block_size) + allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( + allocate_type, + allocator, + prev_block=None, + token_ids=list(range(block_size))) + + [allocate_block() for _ in range(num_blocks)] + with pytest.raises(BlockAllocator.NoFreeBlocksError): + allocate_block() + + @staticmethod + @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) + @pytest.mark.parametrize("num_blocks", [1, 1024]) + @pytest.mark.parametrize("block_size", [1, 16]) + def test_free_prevents_oom(allocate_type: str, num_blocks: int, + block_size: int): + allocator = NaiveBlockAllocator(create_block=NaiveBlock, + num_blocks=num_blocks, + block_size=block_size) + allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( + allocate_type, + allocator, + prev_block=None, + token_ids=list(range(block_size))) + + blocks = [allocate_block() for _ in range(num_blocks)] + + with pytest.raises(BlockAllocator.NoFreeBlocksError): + allocate_block() + + block_to_free = blocks.pop() + + for _ in range(100): + block_id = block_to_free.block_id + allocator.free(block_to_free) + assert block_to_free.block_id is None + + new_block = allocate_block() + assert new_block.block_id == block_id + + with pytest.raises(BlockAllocator.NoFreeBlocksError): + allocate_block() + + block_to_free = new_block + + @staticmethod + @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"]) + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + def test_get_num_free_blocks(allocate_type: str, num_blocks: int, + block_size: int): + allocator = NaiveBlockAllocator(create_block=NaiveBlock, + num_blocks=num_blocks, + block_size=block_size) + allocate_block = TestNaiveBlockAllocator.create_allocate_lambda( + allocate_type, + allocator, + prev_block=None, + token_ids=list(range(block_size))) + + assert allocator.get_num_free_blocks() == num_blocks + + blocks = [allocate_block() for _ in range(num_blocks)] + + for i, block in enumerate(blocks): + assert allocator.get_num_free_blocks() == i + allocator.free(block) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py new file mode 100644 index 00000000..5f4d58dd --- /dev/null +++ b/tests/core/block/test_prefix_caching_block.py @@ -0,0 +1,384 @@ +import math +import random +from typing import List, Optional +from unittest.mock import MagicMock + +import pytest + +from vllm.core.block.interfaces import Block, BlockAllocator +from vllm.core.block.prefix_caching_block import (PrefixCachingBlock, + PrefixCachingBlockAllocator) + + +class TestPrefixCachingBlock: + + @staticmethod + @pytest.mark.parametrize("seed", list(range(10))) + @pytest.mark.parametrize("block_size", [1, 16]) + @pytest.mark.parametrize("is_curr_block_full", [True, False]) + def 
test_first_block_has_correct_content_hash(seed: int, block_size: int, + is_curr_block_full: bool): + """Verify a block which is first in the sequence has the correct hash. + """ + random.seed(seed) + num_to_fill = block_size if is_curr_block_full else random.randint( + 0, block_size - 1) + token_ids = list(range(num_to_fill)) + mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) + + block_with_prev = PrefixCachingBlock( + prev_block=None, + token_ids=token_ids, + block_size=block_size, + prefix_caching_allocator=mock_allocator) + + if is_curr_block_full: + # Expect hash since block is full. + assert block_with_prev.content_hash == ( + PrefixCachingBlock.hash_block_tokens( + is_first_block=True, + prev_block_hash=None, + cur_block_token_ids=token_ids)) + else: + # Do not expect hash since block is not full. + assert block_with_prev.content_hash is None + + @staticmethod + @pytest.mark.parametrize("seed", list(range(10))) + @pytest.mark.parametrize("block_size", [1, 16]) + @pytest.mark.parametrize("is_curr_block_full", [True, False]) + @pytest.mark.parametrize("prev_block_has_hash", [True, False]) + def test_nth_block_has_correct_content_hash(seed: int, block_size: int, + is_curr_block_full: bool, + prev_block_has_hash: bool): + """Verify a block which is not first in the sequence has the correct + hash. + """ + + random.seed(seed) + + previous_block = MagicMock(spec=PrefixCachingBlock) + prev_block_hash = random.randint(0, 1000) + previous_block.content_hash = (prev_block_hash + if prev_block_has_hash else None) + + num_to_fill = block_size if is_curr_block_full else random.randint( + 0, block_size - 1) + token_ids = list(range(num_to_fill)) + mock_allocator = MagicMock(spec=PrefixCachingBlockAllocator) + + block_with_prev = PrefixCachingBlock( + prev_block=previous_block, + token_ids=token_ids, + block_size=block_size, + prefix_caching_allocator=mock_allocator, + ) + + if is_curr_block_full and prev_block_has_hash: + # Expect hash since block is full and previous block has hash. + assert (block_with_prev.content_hash == + PrefixCachingBlock.hash_block_tokens( + is_first_block=False, + prev_block_hash=prev_block_hash, + cur_block_token_ids=token_ids)) + else: + # Do not expect hash since block is not full or the previous block + # does not have a hash. + assert block_with_prev.content_hash is None + + @staticmethod + @pytest.mark.parametrize("block_size", [1, 2, 16]) + @pytest.mark.parametrize("num_tokens", list(range(3))) + @pytest.mark.parametrize("num_empty_trailing_blocks", [0, 1, 10]) + def test_blocks_have_correct_hash_in_chain(block_size: int, + num_tokens: int, + num_empty_trailing_blocks: int): + """Create two chains of logical blocks with the same contents. + Assert the hashes are equal. + """ + random.seed(0) + + token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)] + + first_chain, second_chain = [ + TestPrefixCachingBlock.create_chain( + block_size=block_size, + token_ids=token_ids, + num_empty_trailing_blocks=num_empty_trailing_blocks) + for _ in range(2) + ] + + for first_chain_block, second_chain_block in zip( + first_chain, second_chain): + assert (first_chain_block.content_hash == + second_chain_block.content_hash) + + if not first_chain or not second_chain: + assert first_chain == second_chain + assert num_tokens == 0 + + @staticmethod + def create_chain(block_size: int, + token_ids: List[int], + num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]: + """Helper method which creates a chain of blocks. 
+ """ + blocks = [] + num_blocks = math.ceil( + len(token_ids) / block_size) + num_empty_trailing_blocks + + if num_blocks == 0: + return [] + + allocator = MagicMock(spec=PrefixCachingBlockAllocator) + + prev_block = None + for block_number in range(0, num_blocks): + prev_block = PrefixCachingBlock( + prev_block=prev_block, + token_ids=[], + block_size=block_size, + prefix_caching_allocator=allocator, + ) + + tokens_to_append = token_ids[block_number * + block_size:(block_number + 1) * + block_size] + if tokens_to_append: + prev_block.append_token_ids(tokens_to_append) + + blocks.append(prev_block) + + return blocks + + +class TestPrefixCachingBlockAllocator: + + @staticmethod + def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator, + prev_block: Optional[Block], + token_ids: List[int]): + if allocate_type == "immutable": + allocate_block = lambda: allocator.allocate_immutable( + prev_block=prev_block, token_ids=token_ids) + elif allocate_type == "mutable": + allocate_block = lambda: allocator.allocate_mutable(prev_block= + prev_block) + else: + raise ValueError() + + return allocate_block + + @staticmethod + @pytest.mark.parametrize("num_blocks", [1, 1024]) + @pytest.mark.parametrize("block_size", [1, 16]) + def test_allocate_mutable_ooms(num_blocks: int, block_size: int): + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda( + allocate_type="mutable", + allocator=allocator, + prev_block=None, + token_ids=list(range(block_size)), + ) + + [allocate_block() for _ in range(num_blocks)] + with pytest.raises(BlockAllocator.NoFreeBlocksError): + allocate_block() + + @staticmethod + @pytest.mark.parametrize("num_blocks", [1, 1024]) + @pytest.mark.parametrize("block_size", [1, 16]) + def test_allocate_immutable_does_not_oom_single_hash( + num_blocks: int, block_size: int): + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda( + allocate_type="immutable", + allocator=allocator, + prev_block=None, + token_ids=list(range(block_size)), + ) + + blocks = [allocate_block() for _ in range(num_blocks)] + + # Expect no OOM. If these were mutable blocks, this would OOM. + non_oom_block = allocate_block() + + # Expect all blocks to have same physical block index. + for block in blocks: + assert (block.block_id == non_oom_block.block_id) + + @staticmethod + @pytest.mark.parametrize("num_blocks", [1, 1024]) + @pytest.mark.parametrize("block_size", [1, 16]) + def test_allocate_immutable_ooms_many_hash(num_blocks: int, + block_size: int): + """Consume all blocks using many different hashes/block content. + + Do this by creating a sequence that is very long. + Expect next block to OOM. + """ + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + + # Create token ids that will exhaust all blocks. + token_ids = list(range(num_blocks * block_size)) + + chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + # Expect allocation with unseen hash to fail. + with pytest.raises(BlockAllocator.NoFreeBlocksError): + allocator.allocate_immutable(prev_block=chain[-1], + token_ids=list(range(block_size))) + + # Expect mutable allocation to fail. 
+ with pytest.raises(BlockAllocator.NoFreeBlocksError): + allocator.allocate_mutable(prev_block=chain[-1]) + + # Expect allocation of exact same chain to pass. + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + # Expect physical block indices to be the same in both chains. + assert chain and second_chain + for first_chain_block, second_chain_block in zip(chain, second_chain): + assert (first_chain_block.block_id == second_chain_block.block_id) + + @staticmethod + @pytest.mark.parametrize("num_blocks", [1, 1024]) + @pytest.mark.parametrize("block_size", [1, 16]) + def test_free_prevents_oom(num_blocks: int, block_size: int): + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + + # Create token ids that will exhaust all blocks. + token_ids = list(range(num_blocks * block_size)) + + chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + # Expect mutable allocation to fail. + with pytest.raises(BlockAllocator.NoFreeBlocksError): + allocator.allocate_mutable(prev_block=None) + + block_to_free = chain[-1] + + # Expect free/allocate loop to succeed many times. + for i in range(100): + block_id = block_to_free.block_id + allocator.free(block_to_free) + assert block_to_free.block_id is None, i + + new_block = allocator.allocate_mutable(prev_block=None) + assert new_block.block_id == block_id, i + + with pytest.raises(BlockAllocator.NoFreeBlocksError): + allocator.allocate_mutable(prev_block=None) + + block_to_free = new_block + + @staticmethod + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(20))) + def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int): + random.seed(seed) + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + num_blocks_to_consume = random.randint(1, num_blocks - 1) + + # Create token ids that will exhaust all blocks. + token_ids = list(range(num_blocks_to_consume * block_size)) + + chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + # Free each block in chain, assert num free blocks includes new free + # block. + for i, block in enumerate(chain): + assert allocator.get_num_free_blocks() == (num_blocks - + num_blocks_to_consume + + i) + allocator.free(block) + + @staticmethod + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(20))) + def test_get_num_free_blocks_shared(num_blocks: int, block_size: int, + seed: int): + """Verify sharing occurs by allocating two sequences that share prefixes + and incrementally freeing blocks. + """ + random.seed(seed) + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + num_blocks_to_consume = random.randint(1, num_blocks - 1) + + # Create token ids that will exhaust all blocks. 
+ token_ids = list(range(num_blocks_to_consume * block_size)) + + first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + # Free each block in the first chain. Since all blocks are shared, the + # free count should stay constant. + for i, block in enumerate(first_chain): + assert allocator.get_num_free_blocks() == (num_blocks - + num_blocks_to_consume) + allocator.free(block) + + # Free each block in the second chain. Since the refcount is now zero, + # the free count should increment with each free. + for i, block in enumerate(second_chain): + assert allocator.get_num_free_blocks() == (num_blocks - + num_blocks_to_consume + + i) + allocator.free(block) + + @staticmethod + def create_immutable_chain( + block_size: int, + token_ids: List[int], + allocator: PrefixCachingBlockAllocator, + ) -> List[PrefixCachingBlock]: + """Helper method which creates a chain of blocks. + """ + blocks = [] + num_blocks = math.ceil(len(token_ids) / block_size) + + if num_blocks == 0: + return [] + + prev_block = None + for block_number in range(0, num_blocks): + block_token_ids = token_ids[block_number * + block_size:(block_number + 1) * + block_size] + prev_block = allocator.allocate_immutable( + prev_block=prev_block, token_ids=block_token_ids) + blocks.append(prev_block) + + return blocks diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index ee8e4389..93226cba 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -5,8 +5,9 @@ import pytest from vllm import SamplingParams from vllm.block import PhysicalTokenBlock -from vllm.core.block_manager import (AllocStatus, BlockSpaceManager, - UncachedBlockAllocator) +from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, + UncachedBlockAllocator) +from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -63,10 +64,10 @@ def test_allocate(): block_size = 4 num_cpu_blocks = 4 num_gpu_blocks = 4 - block_manager = BlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) # Allocate same sequence group to all available gpu blocks. for i in range(num_gpu_blocks): @@ -77,10 +78,10 @@ def test_allocate(): # Allocate same sequence group to all available gpu blocks. # Use watermark to reserve one gpu block. - block_manager = BlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=1 / num_gpu_blocks) + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=1 / num_gpu_blocks) for i in range(num_gpu_blocks - 1): _, seq_group = create_dummy_prompt(str(i), block_size) assert block_manager.can_allocate(seq_group) @@ -92,10 +93,10 @@ def test_append_slot_single_seq(): block_size = 4 num_cpu_blocks = 4 num_gpu_blocks = 4 - block_manager = BlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) # Allocate single seq to gpu block. 
prompt, seq_group = create_dummy_prompt("1", block_size) @@ -124,10 +125,10 @@ def test_append_slot_cow(): block_size = 4 num_cpu_blocks = 4 num_gpu_blocks = 4 - block_manager = BlockSpaceManager(block_size=block_size, - num_cpu_blocks=num_cpu_blocks, - num_gpu_blocks=num_gpu_blocks, - watermark=0) + block_manager = BlockSpaceManagerV1(block_size=block_size, + num_cpu_blocks=num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks, + watermark=0) # Allocate prompt to gpu block. There is one slot left in the block. prompt = Sequence(seq_id=1, @@ -165,10 +166,10 @@ def test_fork(): block_size = 4 num_cpu_blocks = 4 num_gpu_blocks = 4 - block_manager = BlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) prompt, seq_group = create_dummy_prompt("1", block_size - 1, @@ -192,10 +193,10 @@ def test_swap(): block_size = 4 num_cpu_blocks = 4 num_gpu_blocks = 4 - block_manager = BlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) prompt.status = SequenceStatus.WAITING @@ -238,10 +239,10 @@ def test_free(): block_size = 4 num_cpu_blocks = 4 num_gpu_blocks = 4 - block_manager = BlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) prompt, seq_group = create_dummy_prompt("1", block_size) block_manager.allocate(seq_group) @@ -262,10 +263,10 @@ def test_reset(): block_size = 4 num_cpu_blocks = 4 num_gpu_blocks = 4 - block_manager = BlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) # Allocate same seq group on all available gpu blocks. 
original_blocks = block_manager.get_num_free_gpu_blocks() @@ -289,11 +290,11 @@ def test_sliding_window_multi_seq(): num_cpu_blocks = 8 num_gpu_blocks = 8 sliding_window = 2 - block_manager = BlockSpaceManager(block_size, - num_cpu_blocks, - num_gpu_blocks, - sliding_window=sliding_window, - watermark=0) + block_manager = BlockSpaceManagerV1(block_size, + num_cpu_blocks, + num_gpu_blocks, + sliding_window=sliding_window, + watermark=0) assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index c66809c6..f40969cf 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -13,7 +13,7 @@ from .utils import create_dummy_prompt def test_scheduler_add_seq_group(): block_size = 4 scheduler_config = SchedulerConfig(100, 64, 1) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 scheduler = Scheduler(scheduler_config, cache_config, None) diff --git a/tests/core/utils.py b/tests/core/utils.py index 6469789e..2e462f2a 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -2,7 +2,7 @@ import time from typing import Tuple from vllm import SamplingParams -from vllm.sequence import Sequence, SequenceGroup +from vllm.sequence import Logprob, Sequence, SequenceGroup def create_dummy_prompt( @@ -23,5 +23,42 @@ def create_dummy_prompt( return prompt, seq_group +def create_seq_group( + seq_prompt_lens=1024, + seq_output_lens=(128, ), + request_id='0', + seq_id_start=0, +) -> SequenceGroup: + + assert len(seq_output_lens) > 0 + + prompt_token_ids = [0] * seq_prompt_lens + + seqs = [] + for seq_id_offset, output_len in enumerate(seq_output_lens): + seq = Sequence( + seq_id=seq_id_start + seq_id_offset, + prompt="", + prompt_token_ids=prompt_token_ids, + block_size=16, + ) + + for i in range(output_len): + seq.append_token_id( + token_id=i, + logprobs={i: Logprob(0.0)}, + ) + seqs.append(seq) + + seq_group = SequenceGroup( + request_id=request_id, + seqs=seqs, + sampling_params=SamplingParams(), + arrival_time=time.time(), + ) + + return seq_group + + def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index cb61aac3..305596e1 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -4,7 +4,7 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`. """ import pytest -from vllm.core.block_manager import CachedBlockAllocator +from vllm.core.block_manager_v1 import CachedBlockAllocator from vllm.utils import Device diff --git a/vllm/config.py b/vllm/config.py index baa37cda..5025b046 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -324,6 +324,8 @@ class CacheConfig: vLLM execution. swap_space: Size of the CPU swap space per GPU (in GiB). cache_dtype: Data type for kv cache storage. + forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the + profiled num_gpu_blocks if specified. Does nothing if None. 
""" def __init__( @@ -332,12 +334,14 @@ class CacheConfig: gpu_memory_utilization: float, swap_space: int, cache_dtype: str, + forced_num_gpu_blocks: Optional[int] = None, sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB + self.forced_num_gpu_blocks = forced_num_gpu_blocks self.cache_dtype = cache_dtype self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching @@ -528,6 +532,7 @@ class SchedulerConfig: and generated text). delay_factor: Apply a delay (of delay factor multiplied by previous prompt latency) before scheduling next prompt. + use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. """ def __init__( @@ -535,6 +540,7 @@ class SchedulerConfig: max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, + use_v2_block_manager: bool = False, delay_factor: float = 0.0, ) -> None: if max_num_batched_tokens is not None: @@ -546,6 +552,7 @@ class SchedulerConfig: self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len self.delay_factor = delay_factor + self.use_v2_block_manager = use_v2_block_manager self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py new file mode 100644 index 00000000..793c6698 --- /dev/null +++ b/vllm/core/block/block_table.py @@ -0,0 +1,245 @@ +from typing import List, Optional + +from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator +from vllm.utils import Device, cdiv, chunk_list + + +class BlockTable: + """A class to manage blocks for a specific sequence. + + The BlockTable maps a sequence of tokens to a list of blocks, where each + block represents a contiguous memory allocation for a portion of the + sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is + responsible for allocating and freeing memory for the blocks. + + Args: + block_size (int): The maximum number of tokens that can be stored in a + single block. + block_allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _blocks (Optional[List[Block]], optional): An optional list of existing + blocks to initialize the BlockTable with. If not provided, an empty + BlockTable is created. + + Attributes: + _block_size (int): The maximum number of tokens that can be stored in a + single block. + _allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _blocks (Optional[List[Block]]): The list of blocks managed by this + BlockTable. + _num_full_slots (int): The number of tokens currently stored in the + blocks. + """ + + def __init__( + self, + block_size: int, + block_allocator: DeviceAwareBlockAllocator, + _blocks: Optional[List[Block]] = None, + ): + self._block_size = block_size + self._allocator = block_allocator + self._blocks: Optional[List[Block]] = _blocks + + # Use helper method instead of directly calculating, as blocks + # may not be allocated. + self._num_full_slots = len(self._get_all_token_ids()) + + @staticmethod + def get_num_required_blocks(token_ids: List[int], block_size: int) -> int: + """Calculates the minimum number of blocks required to store a given + sequence of token IDs. + + This assumes worst-case scenario, where every block requires a new + allocation (e.g. ignoring prefix caching). 
+ + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + block_size (int): The maximum number of tokens that can be stored in + a single block. + + Returns: + int: The minimum number of blocks required to store the given + sequence of token IDs. + """ + return cdiv(len(token_ids), block_size) + + def allocate(self, + token_ids: List[int], + device: Device = Device.GPU) -> None: + """Allocates memory blocks for storing the given sequence of token IDs. + + This method allocates the required number of blocks to store the given + sequence of token IDs. + + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + device (Device, optional): The device on which the blocks should be + allocated. Defaults to Device.GPU. + """ + assert not self._is_allocated + assert token_ids + self._blocks = self._allocate_blocks_for_token_ids(prev_block=None, + token_ids=token_ids, + device=device) + self._num_full_slots = len(token_ids) + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends a sequence of token IDs to the existing blocks in the + BlockTable. + + This method appends the given sequence of token IDs to the existing + blocks in the BlockTable. If there is not enough space in the existing + blocks, new blocks are allocated using the `ensure_num_empty_slots` + method to accommodate the additional tokens. + + The token IDs are divided into chunks of size `block_size` (except for + the first chunk, which may be smaller), and each chunk is appended to a + separate block. + + Args: + token_ids (List[int]): The sequence of token IDs to be appended. + """ + assert self._is_allocated + + self.ensure_num_empty_slots(num_empty_slots=len(token_ids)) + + blocks = self._blocks[self._num_full_slots // self._block_size:] + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + token_blocks = [token_ids[:first_chunk_size]] + chunk_list( + token_ids[first_chunk_size:], self._block_size) + + for block, token_block in zip(blocks, token_blocks): + block.append_token_ids(token_block) + + self._num_full_slots += len(token_ids) + + def ensure_num_empty_slots(self, num_empty_slots: int) -> None: + """Ensures that the BlockTable has at least the specified number of + empty slots available. + + This method checks if the BlockTable has enough empty slots (i.e., + available space) to accommodate the requested number of tokens. If not, + it allocates additional blocks on the GPU to ensure that the required + number of empty slots is available. + + Args: + num_empty_slots (int): The minimum number of empty slots required. + """ + # Currently the block table only supports + # appending tokens to GPU blocks. + device = Device.GPU + assert self._is_allocated + + if self._num_empty_slots >= num_empty_slots: + return + + slots_to_allocate = num_empty_slots - self._num_empty_slots + blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) + + for _ in range(blocks_to_allocate): + self._blocks.append( + self._allocator.allocate_mutable(prev_block=self._blocks[-1], + device=device)) + + def fork(self) -> "BlockTable": + """Creates a new BlockTable instance with a copy of the blocks from the + current instance. + + This method creates a new BlockTable instance with the same block size, + block allocator, and a copy of the blocks from the current instance. The + new BlockTable has its own independent set of blocks, but shares the + same underlying memory allocation with the original BlockTable. 
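+        Subsequent writes to shared blocks are handled by the allocator via
+        copy-on-write, so the two tables diverge only where they are
+        appended to.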
+ + Returns: + BlockTable: A new BlockTable instance with a copy of the blocks from + the current instance. + """ + assert self._is_allocated + forked_blocks = self._allocator.fork(self._blocks[-1]) + return BlockTable( + block_size=self._block_size, + block_allocator=self._allocator, + _blocks=forked_blocks, + ) + + def free(self) -> None: + """Frees the memory occupied by the blocks in the BlockTable. + + This method iterates over all the blocks in the `_blocks` list and calls + the `free` method of the `_allocator` object to release the memory + occupied by each block. After freeing all the blocks, the `_blocks` list + is set to `None`. + """ + assert self._is_allocated + for block in self._blocks: + self._allocator.free(block) + self._blocks = None + + @property + def physical_block_ids(self) -> List[int]: + """Returns a list of physical block indices for the blocks in the + BlockTable. + + This property returns a list of integers, where each integer represents + the physical block index of a corresponding block in the `_blocks` list. + The physical block index is a unique identifier for the memory location + occupied by the block. + + Returns: + List[int]: A list of physical block indices for the blocks in the + BlockTable. + """ + assert self._is_allocated + return [block.block_id for block in self._blocks] + + def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], + token_ids: List[int], + device: Device) -> List[Block]: + blocks = [] + for block_token_ids in chunk_list(token_ids, self._block_size): + if len(block_token_ids) == self._block_size: + # If the block is full, create an immutable block. + prev_block = self._allocator.allocate_immutable( + prev_block, token_ids=block_token_ids, device=device) + else: + # Else, partially fill a mutable block with token ids. + prev_block = self._allocator.allocate_mutable( + prev_block=prev_block, device=device) + prev_block.append_token_ids(block_token_ids) + blocks.append(prev_block) + + return blocks + + def _get_all_token_ids(self) -> List[int]: + # NOTE: This function is O(seq_len); use sparingly. + token_ids = [] + + if not self._is_allocated: + return token_ids + + for block in self._blocks: + token_ids.extend(block.token_ids) + + return token_ids + + @property + def _is_allocated(self) -> bool: + return self._blocks is not None + + @property + def _num_empty_slots(self) -> int: + assert self._is_allocated + return len(self._blocks) * self._block_size - self._num_full_slots + + @property + def num_full_slots(self) -> int: + """Returns the total number of tokens currently stored in the + BlockTable. + + Returns: + int: The total number of tokens currently stored in the BlockTable. + """ + return self._num_full_slots diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py new file mode 100644 index 00000000..50c70533 --- /dev/null +++ b/vllm/core/block/common.py @@ -0,0 +1,185 @@ +from collections import defaultdict +from typing import Dict, Iterable, List, Optional + +from vllm.core.block.interfaces import Block, BlockAllocator + +BlockId = int +RefCount = int + + +class RefCounter: + """A class for managing reference counts for a set of block indices. + + The RefCounter class maintains a dictionary that maps block indices to their + corresponding reference counts. It provides methods to increment, decrement, + and retrieve the reference count for a given block index. + + Args: + all_block_indices (Iterable[BlockId]): An iterable of block indices + to initialize the reference counter with. 
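Putting the BlockTable API together, a minimal standalone sketch; it relies on the CpuGpuBlockAllocator defined later in this diff, and the block counts are illustrative:

from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator

allocator = CpuGpuBlockAllocator.create(allocator_type="naive",
                                        num_gpu_blocks=16,
                                        num_cpu_blocks=16,
                                        block_size=4)
table = BlockTable(block_size=4, block_allocator=allocator)

table.allocate(token_ids=list(range(10)))   # 4 + 4 + 2 tokens -> 3 blocks
assert len(table.physical_block_ids) == 3

table.append_token_ids([10, 11, 12])        # spills into a fourth block
assert table.num_full_slots == 13
assert len(table.physical_block_ids) == 4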
+ """ + + def __init__(self, all_block_indices: Iterable[BlockId]): + deduped = set(all_block_indices) + self._refcounts: Dict[BlockId, + RefCount] = {index: 0 + for index in deduped} + + def incr(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + pre_incr_refcount = self._refcounts[block_id] + + assert pre_incr_refcount >= 0 + + post_incr_refcount = pre_incr_refcount + 1 + self._refcounts[block_id] = post_incr_refcount + return post_incr_refcount + + def decr(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + refcount = self._refcounts[block_id] + + assert refcount > 0 + refcount -= 1 + + self._refcounts[block_id] = refcount + + return refcount + + def get(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + return self._refcounts[block_id] + + def as_readonly(self) -> "ReadOnlyRefCounter": + return ReadOnlyRefCounter(self) + + +class ReadOnlyRefCounter: + """A read-only view of the RefCounter class. + + The ReadOnlyRefCounter class provides a read-only interface to access the + reference counts maintained by a RefCounter instance. It does not allow + modifications to the reference counts. + + Args: + refcounter (RefCounter): The RefCounter instance to create a read-only + view for. + """ + + def __init__(self, refcounter: RefCounter): + self._refcounter = refcounter + + def incr(self, block_id: BlockId) -> RefCount: + raise ValueError("Incr not allowed") + + def decr(self, block_id: BlockId) -> RefCount: + raise ValueError("Decr not allowed") + + def get(self, block_id: BlockId) -> RefCount: + return self._refcounter.get(block_id) + + +class CopyOnWriteTracker: + """A class for tracking and managing copy-on-write operations for blocks. + + The CopyOnWriteTracker class maintains a mapping of source block indices to + their corresponding copy-on-write destination block indices. It works in + conjunction with a RefCounter and a BlockAllocator to handle reference + counting and block allocation. + + Args: + refcounter (RefCounter): The reference counter used to track block + reference counts. + allocator (BlockAllocator): The block allocator used to allocate and + free blocks. + """ + + def __init__( + self, + refcounter: RefCounter, + allocator: BlockAllocator, + ): + self._copy_on_writes = defaultdict(list) + self._refcounter = refcounter + self._allocator = allocator + + def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + This method checks the reference count of the given block. If the + reference count is greater than 1, indicating that the block is shared, + a copy-on-write operation is performed. The original block is freed, + and a new block is allocated with the same content. The new block index + is returned. + + Args: + block (Block): The block to check for copy-on-write. + + Returns: + Optional[BlockId]: The block index of the new block if a copy-on + -write operation was performed, or the original block index if + no copy-on-write was necessary. + """ + block_id = block.block_id + if block_id is None: + return block_id + + refcount = self._refcounter.get(block_id) + assert refcount != 0 + if refcount > 1: + src_block_id = block_id + + # Decrement refcount of the old block. + self._allocator.free(block) + + # Allocate a fresh new block. + block_id = self._allocator.allocate_mutable( + prev_block=block.prev_block).block_id + + # Track src/dst copy. 
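A short usage sketch of the refcount rules above: a block id must be known to the counter, decrementing requires a positive count, and the read-only view rejects mutation.

from vllm.core.block.common import RefCounter

counter = RefCounter(all_block_indices=[0, 1, 2])
assert counter.get(0) == 0
assert counter.incr(0) == 1
assert counter.incr(0) == 2      # block 0 is now shared
assert counter.decr(0) == 1

readonly = counter.as_readonly()
assert readonly.get(0) == 1      # reads allowed; incr/decr raise ValueError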
+ self._copy_on_writes[src_block_id].append(block_id) + + return block_id + + def clear_cows(self) -> Dict[BlockId, List[BlockId]]: + """Clears the copy-on-write tracking information and returns the current + state. + + This method returns a dictionary mapping source block indices to lists + of destination block indices for the current copy-on-write operations. + It then clears the internal tracking information. + + Returns: + Dict[BlockId, List[BlockId]]: A dictionary mapping source + block indices to lists of destination block indices for the + current copy-on-write operations. + """ + cows = dict(self._copy_on_writes) + self._copy_on_writes.clear() + return cows + + +def get_all_blocks_recursively(last_block: Block) -> List[Block]: + """Retrieves all the blocks in a sequence starting from the last block. + + This function recursively traverses the sequence of blocks in reverse order, + starting from the given last block, and returns a list of all the blocks in + the sequence. + + Args: + last_block (Block): The last block in the sequence. + + Returns: + List[Block]: A list of all the blocks in the sequence, in the order they + appear. + """ + + def recurse(block: Block, lst: List[Block]) -> None: + if block.prev_block is not None: + recurse(block.prev_block, lst) + lst.append(block) + + all_blocks = [] + recurse(last_block, all_blocks) + return all_blocks diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py new file mode 100644 index 00000000..3135e194 --- /dev/null +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -0,0 +1,206 @@ +from typing import Dict, List, Optional + +from vllm.core.block.interfaces import (Block, BlockAllocator, + DeviceAwareBlockAllocator) +from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.utils import Device + + +class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): + """A block allocator that can allocate blocks on both CPU and GPU memory. + + This class implements the `DeviceAwareBlockAllocator` interface and provides + functionality for allocating and managing blocks of memory on both CPU and + GPU devices. + + The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU + blocks, and allows for allocation, deallocation, forking, and swapping of + blocks across these memory pools. + """ + + @staticmethod + def create( + allocator_type: str, + num_gpu_blocks: int, + num_cpu_blocks: int, + block_size: int, + ) -> DeviceAwareBlockAllocator: + """Creates a CpuGpuBlockAllocator instance with the specified + configuration. + + This static method creates and returns a CpuGpuBlockAllocator instance + based on the provided parameters. It initializes the CPU and GPU block + allocators with the specified number of blocks, block size, and + allocator type. + + Args: + allocator_type (str): The type of block allocator to use for CPU + and GPU blocks. Currently supported values are "naive" and + "prefix_caching". + num_gpu_blocks (int): The number of blocks to allocate for GPU + memory. + num_cpu_blocks (int): The number of blocks to allocate for CPU + memory. + block_size (int): The size of each block in number of tokens. + + Returns: + DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the + specified configuration. + + Notes: + - The block IDs are assigned contiguously, with GPU block IDs coming + before CPU block IDs. 
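The contiguous id split described in the Notes can be stated in a couple of lines (block counts are illustrative):

num_gpu_blocks, num_cpu_blocks = 4, 2
block_ids = list(range(num_gpu_blocks + num_cpu_blocks))
gpu_block_ids = block_ids[:num_gpu_blocks]   # [0, 1, 2, 3]
cpu_block_ids = block_ids[num_gpu_blocks:]   # [4, 5]
assert set(gpu_block_ids).isdisjoint(cpu_block_ids)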
+ """ + block_ids = list(range(num_gpu_blocks + num_cpu_blocks)) + gpu_block_ids = block_ids[:num_gpu_blocks] + cpu_block_ids = block_ids[num_gpu_blocks:] + + if allocator_type == "naive": + gpu_allocator = NaiveBlockAllocator( + create_block=NaiveBlock, + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator = NaiveBlockAllocator( + create_block=NaiveBlock, + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + elif allocator_type == "prefix_caching": + gpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + else: + raise ValueError(f"Unknown allocator type {allocator_type=}") + + return CpuGpuBlockAllocator( + cpu_block_allocator=cpu_allocator, + gpu_block_allocator=gpu_allocator, + ) + + def __init__( + self, + cpu_block_allocator: BlockAllocator, + gpu_block_allocator: BlockAllocator, + ): + assert not ( + cpu_block_allocator.all_block_ids + & gpu_block_allocator.all_block_ids + ), "cpu and gpu block allocators can't have intersection of block ids" + + self._allocators = { + Device.CPU: cpu_block_allocator, + Device.GPU: gpu_block_allocator, + } + + self._block_ids_to_allocator = {} + for _, allocator in self._allocators.items(): + for block_id in allocator.all_block_ids: + self._block_ids_to_allocator[block_id] = allocator + + def allocate_mutable(self, prev_block: Optional[Block], + device: Device) -> Block: + """Allocates a new mutable block on the specified device. + + Args: + prev_block (Optional[Block]): The previous block to in the sequence. + Used for prefix hashing. + device (Device): The device on which to allocate the new block. + + Returns: + Block: The newly allocated mutable block. + """ + return self._allocators[device].allocate_mutable(prev_block) + + def allocate_immutable(self, prev_block: Optional[Block], + token_ids: List[int], device: Device) -> Block: + """Allocates a new immutable block with the provided token IDs on the + specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + token_ids (List[int]): The list of token IDs to be stored in the new + block. + device (Device): The device on which to allocate the new block. + + Returns: + Block: The newly allocated immutable block containing the provided + token IDs. + """ + return self._allocators[device].allocate_immutable( + prev_block, token_ids) + + def free(self, block: Block) -> None: + """Frees the memory occupied by the given block. + + Args: + block (Block): The block to be freed. + """ + allocator = self._block_ids_to_allocator[block.block_id] + return allocator.free(block) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: A new list of blocks that shares the same memory as the + original sequence. + """ + allocator = self._block_ids_to_allocator[last_block.block_id] + return allocator.fork(last_block) + + def get_num_free_blocks(self, device: Device) -> int: + """Returns the number of free blocks available on the specified device. + + Args: + device (Device): The device for which to query the number of free + blocks. 
+ + Returns: + int: The number of free blocks available on the specified device. + """ + return self._allocators[device].get_num_free_blocks() + + def clear_copy_on_writes(self) -> Dict[int, List[int]]: + """Clears the copy-on-write (CoW) state and returns the mapping of + source to destination block IDs. + + Returns: + Dict[int, List[int]]: A dictionary mapping source block IDs to lists + of destination block IDs. + """ + # CoW only supported on GPU + device = Device.GPU + return self._allocators[device].clear_copy_on_writes() + + def mark_blocks_as_computed(self) -> None: + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_computed() + + def get_common_computed_block_ids( + self, seq_block_ids: List[List[int]]) -> List[int]: + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].get_common_computed_block_ids( + seq_block_ids) + + def all_block_ids(self) -> frozenset[int]: + return frozenset(self._block_ids_to_allocator.keys()) diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py new file mode 100644 index 00000000..9f466566 --- /dev/null +++ b/vllm/core/block/interfaces.py @@ -0,0 +1,105 @@ +from abc import ABC, abstractmethod, abstractproperty +from typing import Dict, List, Optional, Protocol + +from vllm.utils import Device + + +class Block(ABC): + + @abstractmethod + def append_token_ids(self, token_ids: List[int]) -> None: + pass + + @abstractproperty + def block_id(self) -> Optional[int]: + pass + + @abstractproperty + def token_ids(self) -> List[int]: + pass + + @abstractproperty + def num_empty_slots(self) -> int: + pass + + @abstractproperty + def is_full(self) -> bool: + pass + + @abstractproperty + def prev_block(self) -> Optional["Block"]: + pass + + class Factory(Protocol): + + @abstractmethod + def __call__( + self, + prev_block: Optional["Block"], + token_ids: List[int], + block_size: int, + allocator: "BlockAllocator", + block_id: Optional[int] = None, + ) -> "Block": + pass + + +class BlockAllocator(ABC): + + @abstractmethod + def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + pass + + @abstractmethod + def allocate_immutable(self, prev_block: Optional[Block], + token_ids: List[int]) -> Block: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @abstractmethod + def get_num_free_blocks(self) -> int: + pass + + @abstractproperty + def all_block_ids(self) -> frozenset[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> Dict[int, List[int]]: + pass + + @abstractmethod + def mark_blocks_as_computed(self) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, seq_block_ids: List[List[int]]) -> List[int]: + pass + + class NoFreeBlocksError(ValueError): + pass + + +class DeviceAwareBlockAllocator(BlockAllocator): + + @abstractmethod + def allocate_mutable(self, prev_block: Optional[Block], + device: Device) -> Block: + pass + + @abstractmethod + def allocate_immutable(self, prev_block: Optional[Block], + token_ids: List[int], device: Device) -> Block: + pass + + @abstractmethod + def get_num_free_blocks(self, device: Device) -> int: + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py new file mode 100644 index 00000000..f8e9265b --- /dev/null +++ b/vllm/core/block/naive_block.py @@ -0,0 +1,275 @@ +from typing import Dict, Iterable, List, Optional, Set + +from 
vllm.core.block.common import (CopyOnWriteTracker, RefCounter, + get_all_blocks_recursively) +from vllm.core.block.interfaces import Block, BlockAllocator + +BlockId = int +Refcount = int + + +class NaiveBlockAllocator(BlockAllocator): + """A simple block allocator that manages blocks of memory without prefix + caching. + + Args: + create_block (Block.Factory): A factory function for creating new + blocks. This is used when a NaiveBlockAllocator is composed within + a prefix caching allocator -- the naive block allocator must + construct prefix caching blocks (but shouldn't know anything else + about them). + num_blocks (int): The total number of blocks to manage. + block_size (int): The size of each block in tokens. + block_ids (Optional[Iterable[int]], optional): An optional iterable of + block IDs. If not provided, block IDs will be assigned sequentially + from 0 to num_blocks - 1. + """ + + def __init__( + self, + create_block: Block.Factory, + num_blocks: int, + block_size: int, + block_ids: Optional[Iterable[int]] = None, + ): + if block_ids is None: + block_ids = range(num_blocks) + + self._free_block_indices: Set[BlockId] = set(block_ids) + self._all_block_indices = frozenset(block_ids) + assert len(self._all_block_indices) == num_blocks + + self._refcounter = RefCounter( + all_block_indices=self._free_block_indices) + self._create_block = create_block + self._block_size = block_size + + self._cow_tracker = CopyOnWriteTracker( + refcounter=self._refcounter.as_readonly(), + allocator=self, + ) + + def allocate_immutable(self, prev_block: Optional[Block], + token_ids: List[int]) -> Block: + """Allocates a new immutable block with the given token IDs, linked to + the previous block. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. If + None, then the block to be allocated is the first block in the + sequence. + token_ids (List[int]): The token IDs to be stored in the new block. + + Returns: + Block: The newly allocated immutable block. + """ + block = self.allocate_mutable(prev_block=prev_block) + block.append_token_ids(token_ids) + return block + + def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + """Allocates a new mutable block, linked to the previous block. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. If + None, then the block to be allocated is the first block in the + sequence. + + Returns: + Block: The newly allocated mutable block. + """ + block_id = self._allocate_new_block_id() + return self._create_block( + prev_block=prev_block, + token_ids=[], + block_id=block_id, + block_size=self._block_size, + allocator=self, + ) + + def free(self, block: Block) -> None: + self._free_block_id(block.block_id) + + # Mark the block as having no allocation. + block.block_id = None + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: The new sequence of blocks that shares the same memory + as the original sequence. + """ + source_blocks = get_all_blocks_recursively(last_block) + + forked_blocks = [] + prev_block = None + for block in source_blocks: + + # Increment refcount for each block. 
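A minimal usage sketch of NaiveBlockAllocator, covering allocate and free (block counts are illustrative):

from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator

allocator = NaiveBlockAllocator(create_block=NaiveBlock,
                                num_blocks=4,
                                block_size=16)
assert allocator.get_num_free_blocks() == 4

block = allocator.allocate_immutable(prev_block=None,
                                     token_ids=list(range(16)))
assert allocator.get_num_free_blocks() == 3

allocator.free(block)
assert block.block_id is None        # freed blocks lose their physical id
assert allocator.get_num_free_blocks() == 4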
+ refcount = self._refcounter.incr(block.block_id) + assert refcount != 1, "can't fork free'd block" + + forked_blocks.append( + self._create_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_id=block.block_id, + block_size=self._block_size, + allocator=self, + )) + prev_block = forked_blocks[-1] + + return forked_blocks + + def get_num_free_blocks(self) -> int: + return len(self._free_block_indices) + + def _allocate_new_block_id(self) -> BlockId: + if not self._free_block_indices: + raise BlockAllocator.NoFreeBlocksError() + + block_id = next(iter(self._free_block_indices)) + self._refcounter.incr(block_id) + self._free_block_indices.remove(block_id) + return block_id + + def _free_block_id(self, block_id: BlockId) -> None: + refcount = self._refcounter.decr(block_id) + if refcount == 0: + self._free_block_indices.add(block_id) + + @property + def refcounter(self): + return self._refcounter + + @property + def all_block_ids(self): + return self._all_block_indices + + def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + Args: + block (Block): The block to check for copy-on-write. + + Returns: + Optional[BlockId]: The block index of the new block if a copy-on + -write operation was performed, or the original block index if + no copy-on-write was necessary. + """ + return self._cow_tracker.cow_block_if_not_appendable(block) + + def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: + """Returns the copy-on-write source->destination mapping and clears it. + + Returns: + Dict[BlockId, List[BlockId]]: A dictionary mapping source + block indices to lists of destination block indices. + """ + return self._cow_tracker.clear_cows() + + def mark_blocks_as_computed(self) -> None: + """Mark blocks as computed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def get_common_computed_block_ids( + self, seq_block_ids: List[List[int]]) -> List[int]: + """Determine blocks that can be skipped in prefill. + + Since the naive allocator does not support prefix caching, always return + an empty list. + """ + return [] + + +class NaiveBlock(Block): + """An implementation of the Block class that does not support prefix + caching. + + The NaiveBlock class represents a block of token IDs with a fixed size. It + provides methods for appending token IDs to the block and manages copy-on + -write operations when necessary. + + Args: + prev_block (Block): The previous block in the sequence. + token_ids (List[int]): The initial token IDs to be stored in the block. + block_size (int): The maximum number of token IDs that can be stored in + the block. + allocator (BlockAllocator): The block allocator associated with this + block. + block_id (Optional[int], optional): The physical block index + of this block. Defaults to None, which means no allocation has been + made. + _cow_target (Optional[Block], optional): The copy-on-write target block. + If not provided, it defaults to self. 
+ """ + + def __init__(self, + prev_block: Block, + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + _cow_target: Optional[Block] = None): + self._token_ids = [] + self._block_size = block_size + self._prev_block = prev_block + self._block_id = block_id + self._allocator = allocator + self._cow_target = _cow_target if _cow_target is not None else self + + self._append_token_ids_no_cow(token_ids) + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block, instructing the allocator + to perform a copy-on-write if necessary. + + Args: + token_ids (List[int]): The token IDs to be appended to the block. + """ + self._append_token_ids_no_cow(token_ids) + + if self._block_id is not None: + self._block_id = (self._allocator.cow_block_if_not_appendable( + self._cow_target)) + + def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: + assert self.num_empty_slots >= len(token_ids) + self._token_ids.extend(token_ids) + + @property + def block_id(self) -> Optional[int]: + return self._block_id + + @block_id.setter + def block_id(self, value: Optional[int]) -> None: + self._block_id = value + + @property + def is_full(self) -> bool: + return self.num_empty_slots == 0 + + @property + def num_empty_slots(self) -> int: + return self._block_size - len(self._token_ids) + + @property + def token_ids(self) -> List[int]: + return self._token_ids + + def block_size(self) -> int: + return self._block_size + + @property + def prev_block(self) -> Optional["Block"]: + return self._prev_block diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py new file mode 100644 index 00000000..6aa75a8a --- /dev/null +++ b/vllm/core/block/prefix_caching_block.py @@ -0,0 +1,472 @@ +"""Token blocks.""" +from itertools import takewhile +from os.path import commonprefix +from typing import Dict, Iterable, List, Optional + +from vllm.core.block.common import (CopyOnWriteTracker, + get_all_blocks_recursively) +from vllm.core.block.interfaces import Block, BlockAllocator +from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator + +PrefixHash = int +BlockId = int + + +class PrefixCachingBlockAllocator(BlockAllocator): + """A block allocator that implements prefix caching. + + The PrefixCachingBlockAllocator maintains a cache of blocks based on their + content hash. It reuses blocks with the same content hash to avoid redundant + memory allocation. The allocator also supports copy-on-write operations. + + Args: + num_blocks (int): The total number of blocks to manage. + block_size (int): The size of each block in tokens. + block_ids(Optional[Iterable[int]], optional): An optional iterable of + block IDs. If not provided, block IDs will be assigned sequentially + from 0 to num_blocks - 1. + """ + + # TODO last access time / evictor integration + + def __init__( + self, + num_blocks: int, + block_size: int, + block_ids: Optional[Iterable[int]] = None, + ): + # A mapping of prefix hash to block index. All blocks which have a + # prefix hash will be in this dict, even if they have refcount 0. + self._cached_blocks: Dict[PrefixHash, BlockId] = {} + + # A mapping of prefix hash to block index. All blocks which have a + # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset + # of self._cached_blocks. + self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {} + + # An allocator for blocks that do not have prefix hashes. 
+ self._hashless_allocator = NaiveBlockAllocator( + create_block=self._create_block, + num_blocks=num_blocks, + block_size=block_size, + block_ids=block_ids, + ) + + self._block_size = block_size + + # We share the refcounter between allocators. This allows us to promote + # blocks originally allocated in the hashless allocator to immutable + # blocks. + self._refcounter = self._hashless_allocator.refcounter + + self._cow_tracker = CopyOnWriteTracker( + refcounter=self._refcounter.as_readonly(), + allocator=self, + ) + + # Implements Block.Factory. + def _create_block( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + ) -> Block: + # Bind block to self. + allocator = self + + return PrefixCachingBlock( + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + prefix_caching_allocator=allocator, + ) + + def allocate_immutable(self, prev_block: Optional[Block], + token_ids: List[int]) -> Block: + """Allocates an immutable block with the given token IDs, reusing cached + blocks if possible. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + token_ids (List[int]): The token IDs to be stored in the block. + + Returns: + Block: The allocated immutable block. + """ + assert_prefix_caching_block_or_none(prev_block) + + block = self._create_block( + prev_block=prev_block, + token_ids=token_ids, + block_size=self._block_size, + allocator=self, + ) + assert block.content_hash is not None + + cached_block_id = self._cached_blocks.get(block.content_hash, None) + if cached_block_id is not None: + block.block_id = cached_block_id + self._incr_refcount_cached_block(block.content_hash, + block.block_id) + return block + + block = self.allocate_mutable(prev_block) + block.append_token_ids(token_ids) + assert block.content_hash is not None + # TODO computed bit + + return block + + def allocate_mutable(self, prev_block: Block) -> Block: + """Allocates a mutable block. If there are no free blocks, this will + evict unused cached blocks. + + Args: + prev_block (Block): The previous block in the sequence. + + Returns: + Block: The allocated mutable block. + """ + assert_prefix_caching_block_or_none(prev_block) + + try: + return self._hashless_allocator.allocate_mutable( + prev_block=prev_block) + except BlockAllocator.NoFreeBlocksError: + # We must check the unused cached blocks before raising OOM. + pass + + if self._unused_cached_blocks: + # TODO policy for selecting block to remove + content_hash_to_evict = next(iter(self._unused_cached_blocks)) + + # Clear content hash mapping; the block will be overwritten. + del self._cached_blocks[content_hash_to_evict] + + block_id = self._unused_cached_blocks.pop(content_hash_to_evict) + refcount = self._refcounter.incr(block_id) + assert refcount == 1 + block = self._create_block( + prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + allocator=self, + block_id=block_id, + ) + assert block.content_hash is None + return block + + # No block available in hashless allocator, nor in unused cache blocks. + raise BlockAllocator.NoFreeBlocksError() + + def _incr_refcount_cached_block(self, content_hash: int, + block_id: BlockId) -> None: + refcount = self._refcounter.incr(block_id) + if refcount == 1: + assert content_hash in self._unused_cached_blocks + del self._unused_cached_blocks[content_hash] + + def free(self, block: Block) -> None: + """Decrement the refcount of the block. 
If the decremented refcount is + zero, store the block in the freelist. + + If the block has a content hash (meaning it is immutable), then we will + keep the block around in case future allocations require it. + """ + assert (block.block_id + is not None), "freeing unallocated block is undefined" + + self._free_block_id_for_block(block.block_id, block) + block.block_id = None + + def _free_block_id_for_block(self, block_id: BlockId, + block: Block) -> None: + assert isinstance(block, PrefixCachingBlock) + + if block.content_hash is None: + return self._hashless_allocator.free(block) + + refcount = self._refcounter.decr(block_id) + + # If no longer used, add the block to the unused cached blocks. + if refcount == 0: + assert block.content_hash not in self._unused_cached_blocks + assert block.content_hash in self._cached_blocks + self._unused_cached_blocks[block.content_hash] = block_id + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: The new sequence of blocks that shares the same memory + as the original sequence. + """ + source_blocks = get_all_blocks_recursively(last_block) + + forked_blocks = [] + prev_block = None + for block in source_blocks: + refcount = self._refcounter.incr(block.block_id) + assert refcount != 1, "can't fork free'd block" + + forked_blocks.append( + self._create_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_id=block.block_id, + block_size=self._block_size, + allocator=self, + )) + prev_block = forked_blocks[-1] + + return forked_blocks + + def get_num_free_blocks(self) -> int: + # The number of free blocks is the number of hashless free blocks + # plus the number of hashful blocks that are unused. + return self._hashless_allocator.get_num_free_blocks() + len( + self._unused_cached_blocks) + + @property + def all_block_ids(self) -> frozenset[int]: + return self._hashless_allocator.all_block_ids + + def promote_to_immutable_block(self, + block: "PrefixCachingBlock") -> BlockId: + """Once a mutable block is full, it can be promoted to an immutable + block. This means that its content can be referenced by future blocks + having the same prefix. + + Note that if we already have a cached block with the same content, we + will replace the newly-promoted block's mapping with the existing cached + block. + + Args: + block (PrefixCachingBlock): The mutable block to be promoted. + + Returns: + BlockId: Either the original block index, or the block index of + the previously cached block matching the same content. + """ + assert block.content_hash is not None + assert block.block_id is not None + assert self._refcounter.get(block.block_id) > 0 + + # If the content hash does not have a corresponding cached block, + # set this block as the cached block. + if block.content_hash not in self._cached_blocks: + self._cached_blocks[block.content_hash] = block.block_id + else: + self._free_block_id_for_block(block.block_id, block) + self._incr_refcount_cached_block( + block.content_hash, self._cached_blocks[block.content_hash]) + + return self._cached_blocks[block.content_hash] + + def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + Args: + block (Block): The block to check for copy-on-write. 
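The cache-hit path that free and promote_to_immutable_block enable can be seen end to end in a small standalone sketch (block counts are illustrative):

from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator

allocator = PrefixCachingBlockAllocator(num_blocks=4, block_size=4)

first = allocator.allocate_immutable(prev_block=None, token_ids=[1, 2, 3, 4])
second = allocator.allocate_immutable(prev_block=None, token_ids=[1, 2, 3, 4])

# Identical full-block content maps to the same physical block, so the
# second allocation does not consume another free block.
assert first.block_id == second.block_id
assert allocator.get_num_free_blocks() == 3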
+ + Returns: + Optional[BlockId]: The block index of the new block if a copy-on + -write operation was performed, or the original block index if + no copy-on-write was necessary. + """ + return self._cow_tracker.cow_block_if_not_appendable(block) + + def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: + """Returns the copy-on-write source->destination mapping and clears it. + + Returns: + Dict[BlockId, List[BlockId]]: A dictionary mapping source + block indices to lists of destination block indices. + """ + return self._cow_tracker.clear_cows() + + def mark_blocks_as_computed(self) -> None: + """Mark blocks as computed, used in prefix caching.""" + # TODO Track computed blocks. + pass + + def get_common_computed_block_ids( + self, seq_block_ids: List[List[int]]) -> List[int]: + """Return the block ids that are common for a given sequence group. + + Used in prefill (can skip prefill of some blocks). + """ + + # TODO: Track computed blocks. + computed = lambda block_id: False + + # NOTE We exclude the last block to avoid the case where the entire + # prompt is cached. This would cause erroneous behavior in model + # runner. + ids_list = [ + takewhile(lambda block_id: computed(block_id), seq[:-1]) + for seq in seq_block_ids + ] + return commonprefix([ids for ids in ids_list if ids != []]) + + +class PrefixCachingBlock(Block): + """A block implementation that supports prefix caching. + + The PrefixCachingBlock class represents a block of token IDs with prefix + caching capabilities. It wraps a NaiveBlock internally and provides + additional functionality for content hashing and promoting immutable blocks + with the prefix caching allocator. + + Args: + prev_block (Optional[PrefixCachingBlock]): The previous block in the + sequence. + token_ids (List[int]): The initial token IDs to be stored in the block. + block_size (int): The maximum number of token IDs that can be stored in + the block. + prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix + caching block allocator associated with this block. + block_id (Optional[int], optional): The physical block index + of this block. Defaults to None. + """ + + def __init__( + self, + prev_block: Optional["PrefixCachingBlock"], + token_ids: List[int], + block_size: int, + prefix_caching_allocator: PrefixCachingBlockAllocator, + block_id: Optional[int] = None, + ): + assert_prefix_caching_block_or_none(prev_block) + + self._prev_block = prev_block + self._cached_content_hash: Optional[int] = None + self._prefix_caching_allocator = prefix_caching_allocator + + self._block = NaiveBlock( + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=prefix_caching_allocator, + _cow_target=self, + ) + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block and registers the block as + immutable if the block becomes full. + + Internally, the naive block handles CoW. + + Args: + token_ids (List[int]): The token IDs to be appended to the block. + """ + assert token_ids + + # naive block handles CoW. + self._block.append_token_ids(token_ids) + + # If the content hash is present, then the block can be made immutable. + # Register ourselves with the allocator, potentially replacing the + # physical block index. + if self.content_hash is not None: + self.block_id = (self._prefix_caching_allocator. 
+ promote_to_immutable_block(self)) + + @property + def block_id(self) -> Optional[int]: + return self._block.block_id + + @block_id.setter + def block_id(self, value) -> None: + self._block.block_id = value + + @property + def is_full(self) -> bool: + return self._block.is_full + + @property + def num_empty_slots(self) -> int: + return self._block.num_empty_slots + + @property + def block_size(self) -> int: + return self._block.block_size + + @property + def token_ids(self) -> List[int]: + return self._block.token_ids + + @property + def prev_block(self) -> Optional[Block]: + return self._prev_block + + @property + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined. + + For the content-based hash to be defined, the current block must be + full. + """ + + # If the hash is already computed, return it. + if self._cached_content_hash is not None: + return self._cached_content_hash + + # We cannot compute a hash for the current block because it is not full. + if not self.is_full: + return None + + is_first_block = self._prev_block is None + prev_block_hash = (None if is_first_block else + self._prev_block.content_hash) + + # Previous block exists but does not yet have a hash. + # Return no hash in this case. + if prev_block_hash is None and not is_first_block: + return None + + self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block, + prev_block_hash, + cur_block_token_ids=self.token_ids) + return self._cached_content_hash + + @staticmethod + def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], + cur_block_token_ids: List[int]) -> int: + """Computes a hash value corresponding to the contents of a block and + the contents of the preceding block(s). The hash value is used for + prefix caching. + + NOTE: Content-based hashing does not yet support LoRA. + + Parameters: + - is_first_block (bool): A flag indicating if the block is the first in + the sequence. + - prev_block_hash (Optional[int]): The hash of the previous block. None + if this is the first block. + - cur_block_token_ids (List[int]): A list of token ids in the current + block. The current block is assumed to be full. + + Returns: + - int: The computed hash value for the block. 
+ """ + assert (prev_block_hash is None) == is_first_block + return hash((is_first_block, prev_block_hash, *cur_block_token_ids)) + + +def assert_prefix_caching_block_or_none(block: Optional[Block]): + if block is None: + return + assert isinstance(block, PrefixCachingBlock) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager_v1.py similarity index 96% rename from vllm/core/block_manager.py rename to vllm/core/block_manager_v1.py index c6fca413..c5c8d0a0 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager_v1.py @@ -1,5 +1,4 @@ """A block manager that manages token blocks.""" -import enum from abc import ABC, abstractmethod from itertools import count, takewhile from os.path import commonprefix @@ -7,6 +6,7 @@ from typing import Dict, List, Optional, Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device @@ -196,21 +196,7 @@ class UncachedBlockAllocator(BlockAllocatorBase): "Invalid codepath for uncached block allocator.") -class AllocStatus(enum.Enum): - """Result for BlockSpaceManager.can_allocate - - 1. Ok: seq_group can be allocated now. - 2. Later: seq_group cannot be allocated. - The capacity of allocator is larger than seq_group required. - 3. Never: seq_group can never be allocated. - The seq_group is too large to allocated in GPU. - """ - OK = enum.auto() - LATER = enum.auto() - NEVER = enum.auto() - - -class BlockSpaceManager: +class BlockSpaceManagerV1(BlockSpaceManager): """Manages the mapping between logical and physical token blocks.""" def __init__( @@ -355,6 +341,11 @@ class BlockSpaceManager: self, seq: Sequence, ) -> PhysicalTokenBlock: + # Called before a new block is appended. + # This is in charge of allocating a new physical block (to be appended). + + # None if the last block is not full. Otherwise, we set it to the + # content hash. if not self.enable_caching: return self.gpu_allocator.allocate() block_hash: Optional[int] = None @@ -362,7 +353,14 @@ class BlockSpaceManager: block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1) num_hashed_tokens = seq.num_hashed_tokens_of_block( len(seq.logical_token_blocks) - 1) + + # num_hashed_tokens is used to compute future hashes + # (e.g. in the hashing function, it is used to ask the sequence for + # prefix tokens) new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens) + + # If the block has is None, then the block is not full. + # If the block is not full, then we expect it to have a refcount of 1. if block_hash is None: assert new_block.ref_count == 1 return new_block @@ -576,16 +574,16 @@ class BlockSpaceManager: for b in takewhile(lambda b: b.computed, block_table[:-1]) ] - def get_common_computed_block_ids(self, - seq_group: SequenceGroup) -> List[int]: + def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + """Return the block ids that are common for a given sequence group. + + Used in prefill (can skip prefill of some blocks). + """ # Can return non-empty result only with prefix caching enabled. 
if not self.enable_caching: return [] - ids_list = [ - self.get_all_computed_blocks(seq) - for seq in iter(seq_group.seqs_dict.values()) - ] + ids_list = [self.get_all_computed_blocks(seq) for seq in seqs] return commonprefix([ids for ids in ids_list if ids != []]) def mark_blocks_as_computed(self, seq_group: SequenceGroup): diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py new file mode 100644 index 00000000..37c70073 --- /dev/null +++ b/vllm/core/block_manager_v2.py @@ -0,0 +1,210 @@ +"""A block manager that manages token blocks.""" +from typing import Dict, List, Optional, Tuple + +from vllm.core.block.block_table import BlockTable +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device + +SeqId = int + + +class BlockSpaceManagerV2(BlockSpaceManager): + """BlockSpaceManager which manages the allocation of KV cache. + + It owns responsibility for allocation, swapping, allocating memory for + autoregressively-generated tokens, and other advanced features such as + prefix caching, forking/copy-on-write, and sliding-window memory allocation. + + The current implementation is partial; in particular prefix caching and + sliding-window are not feature complete. This class implements the design + described in https://github.com/vllm-project/vllm/pull/3492. + + Args: + block_size (int): The size of each memory block. + num_gpu_blocks (int): The number of memory blocks allocated on GPU. + num_cpu_blocks (int): The number of memory blocks allocated on CPU. + watermark (float, optional): The threshold used for memory swapping. + Defaults to 0.01. + sliding_window (Optional[int], optional): The size of the sliding + window. Defaults to None. + enable_caching (bool, optional): Flag indicating whether caching is + enabled. Defaults to False. + """ + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + enable_caching: bool = False, + ) -> None: + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + assert sliding_window is None, "Sliding window not yet supported" + + self.block_sliding_window = None + + self.watermark = watermark + assert watermark >= 0.0 + + assert not enable_caching, "Prefix caching not yet supported" + self.enable_caching = enable_caching + + self.watermark_blocks = int(watermark * num_gpu_blocks) + + self.block_allocator = CpuGpuBlockAllocator.create( + # Currently, only naive blocks are supported (no prefix caching). + allocator_type="naive", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + self.block_tables: Dict[SeqId, BlockTable] = {} + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. 
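A quick standalone construction of the v2 manager, showing the watermark bookkeeping set up in __init__ (block counts are illustrative):

from vllm.core.block_manager_v2 import BlockSpaceManagerV2

manager = BlockSpaceManagerV2(block_size=16,
                              num_gpu_blocks=1000,
                              num_cpu_blocks=1000,
                              watermark=0.01)
# 1% of the GPU blocks are held back to avoid frequent cache eviction.
assert manager.watermark_blocks == 10
assert manager.get_num_free_gpu_blocks() == 1000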
+ seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + + num_required_blocks = BlockTable.get_num_required_blocks( + seq.get_token_ids(), + block_size=self.block_size, + ) + + assert self.block_sliding_window is None + if self.block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, + self.block_sliding_window) + + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + device=Device.GPU) + + # Use watermark to avoid frequent cache eviction. + if (self.num_total_gpu_blocks - num_required_blocks < + self.watermark_blocks): + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def allocate(self, seq_group: SequenceGroup) -> None: + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert not (set(seq.seq_id for seq in waiting_seqs) + & self.block_tables.keys()), "block table already exists" + + # NOTE: Here we assume that all sequences in the group have the same + # prompt. + seq = waiting_seqs[0] + + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + ) + assert self.block_sliding_window is None + block_table.allocate(seq.get_token_ids()) + self.block_tables[seq.seq_id] = block_table + + # Assign the block table for each sequence. + for seq in waiting_seqs[1:]: + self.block_tables[seq.seq_id] = block_table.fork() + + def can_append_slot(self, seq_group: SequenceGroup) -> bool: + # Simple heuristic: If there is at least one free block + # for each sequence, we can append. + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + Device.GPU) + num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) + return num_seqs <= num_free_gpu_blocks + + def append_slot( + self, + seq: Sequence, + ) -> Optional[Tuple[int, int]]: + + block_table = self.block_tables[seq.seq_id] + + # Get unseen token ids. + num_full_slots = block_table.num_full_slots + unseen_token_ids = seq.get_token_ids()[num_full_slots:] + assert unseen_token_ids + + block_table.append_token_ids(unseen_token_ids) + + # Return any copy-on-writes. + _ = self.block_allocator.clear_copy_on_writes() + + # TODO extend append_slot interface to append_slots + # @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250 + + return None + + def free(self, seq: Sequence) -> None: + if seq.seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. + return + self.block_tables[seq.seq_id].free() + del self.block_tables[seq.seq_id] + + def get_block_table(self, seq: Sequence) -> List[int]: + assert seq.seq_id in self.block_tables + block_ids = self.block_tables[seq.seq_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids + + def access_all_blocks_in_seq(self, seq, now): + # TODO add prefix caching support. + # Tracked here https://github.com/vllm-project/vllm/issues/3667 + pass + + def mark_blocks_as_computed(self, seq_group: SequenceGroup): + # We ignore the sequence group as its not necessary. After the batch is + # formed by the scheduler, we do not need to mark blocks from individual + # sequence groups as computed -- all blocks in the batch can be marked + # as computed. + self.block_allocator.mark_blocks_as_computed() + + def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + """Determine which blocks for which we skip prefill. + + With prefix caching we can skip prefill for previously-generated blocks. 
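The can_allocate decision earlier in this file reduces to a simple rule; a standalone restatement with local names, returning strings instead of AllocStatus for brevity:

def can_allocate(num_required_blocks: int, num_free_gpu_blocks: int,
                 num_total_gpu_blocks: int, watermark_blocks: int) -> str:
    # Mirrors the checks in BlockSpaceManagerV2.can_allocate.
    if num_total_gpu_blocks - num_required_blocks < watermark_blocks:
        return "NEVER"   # the request can never fit on this GPU
    if num_free_gpu_blocks - num_required_blocks >= watermark_blocks:
        return "OK"
    return "LATER"       # wait for running sequences to free blocks

assert can_allocate(5, 100, 100, 10) == "OK"
assert can_allocate(95, 100, 100, 10) == "NEVER"
assert can_allocate(5, 10, 100, 10) == "LATER"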
+ Currently, the attention implementation only supports skipping cached + blocks if they are a contiguous prefix of cached blocks. + + This method determines which blocks can be safely skipped for all + sequences in the sequence group. + """ + seq_block_ids = [ + self.block_tables[seq.seq_id].physical_block_ids for seq in seqs + ] + return self.block_allocator.get_common_computed_block_ids( + seq_block_ids) + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + src_block_table = self.block_tables[parent_seq.seq_id] + self.block_tables[child_seq.seq_id] = src_block_table.fork() + + def can_swap_in(self, seq_group: SequenceGroup) -> bool: + return False + + def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + raise NotImplementedError + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return False + + def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + raise NotImplementedError + + def get_num_free_gpu_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(Device.GPU) + + def get_num_free_cpu_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(Device.CPU) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py new file mode 100644 index 00000000..48524de0 --- /dev/null +++ b/vllm/core/interfaces.py @@ -0,0 +1,107 @@ +import enum +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple + +from vllm.sequence import Sequence, SequenceGroup + + +class AllocStatus(enum.Enum): + """Result for BlockSpaceManager.can_allocate + + 1. Ok: seq_group can be allocated now. + 2. Later: seq_group cannot be allocated. + The capacity of allocator is larger than seq_group required. + 3. Never: seq_group can never be allocated. + The seq_group is too large to allocated in GPU. 
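fork at the manager level ultimately relies on allocator-level fork plus copy-on-write; a standalone sketch using the naive allocator from earlier in this diff (block counts are illustrative):

from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator

allocator = NaiveBlockAllocator(create_block=NaiveBlock,
                                num_blocks=4,
                                block_size=4)

parent = allocator.allocate_mutable(prev_block=None)
parent.append_token_ids([1, 2])

# Forking shares the same physical block between parent and child...
child = allocator.fork(parent)[-1]
assert child.block_id == parent.block_id

# ...until one of them appends again, at which point copy-on-write gives
# the writer a fresh physical block.
parent.append_token_ids([3])
assert parent.block_id != child.block_id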
+ """ + OK = enum.auto() + LATER = enum.auto() + NEVER = enum.auto() + + +class BlockSpaceManager(ABC): + + @staticmethod + def get_block_space_manager_class(version: str): + version = version.lower() + + if version == "v1": + from vllm.core.block_manager_v1 import BlockSpaceManagerV1 + return BlockSpaceManagerV1 + + if version == "v2": + from vllm.core.block_manager_v2 import BlockSpaceManagerV2 + return BlockSpaceManagerV2 + + raise ValueError(f"Unknown version {version=}") + + @abstractmethod + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + pass + + @abstractmethod + def allocate(self, seq_group: SequenceGroup) -> None: + pass + + @abstractmethod + def can_append_slot(self, seq_group: SequenceGroup) -> bool: + pass + + @abstractmethod + def append_slot( + self, + seq: Sequence, + ) -> Optional[Tuple[int, int]]: + pass + + @abstractmethod + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + @abstractmethod + def can_swap_in(self, seq_group: SequenceGroup) -> bool: + pass + + @abstractmethod + def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + pass + + @abstractmethod + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + pass + + @abstractmethod + def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + pass + + @abstractmethod + def free(self, seq: Sequence) -> None: + pass + + @abstractmethod + def get_block_table(self, seq: Sequence) -> List[int]: + pass + + @abstractmethod + def get_num_free_gpu_blocks(self) -> int: + pass + + @abstractmethod + def get_num_free_cpu_blocks(self) -> int: + pass + + @abstractmethod + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: + pass + + @abstractmethod + def mark_blocks_as_computed(self, seq_group: SequenceGroup): + pass diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 899816b6..85c2fdf7 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -4,7 +4,7 @@ from collections import deque from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -from vllm.core.block_manager import AllocStatus, BlockSpaceManager +from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -88,8 +88,13 @@ class Scheduler: # Instantiate the scheduling policy. self.policy = PolicyFactory.get_policy(policy_name="fcfs") + + BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( + version="v2" if self.scheduler_config. + use_v2_block_manager else "v1") + # Create the block space manager. 
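A quick sketch of the version selection that the scheduler change below relies on:

from vllm.core.interfaces import BlockSpaceManager

impl = BlockSpaceManager.get_block_space_manager_class(version="v2")
assert impl.__name__ == "BlockSpaceManagerV2"

# Unknown versions are rejected.
try:
    BlockSpaceManager.get_block_space_manager_class(version="v3")
except ValueError as exc:
    print(exc)   # Unknown version version='v3'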
- self.block_manager = BlockSpaceManager( + self.block_manager = BlockSpaceManagerImpl( block_size=self.cache_config.block_size, num_gpu_blocks=self.cache_config.num_gpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks, @@ -378,6 +383,10 @@ class Scheduler: block_tables[seq_id] = self.block_manager.get_block_table(seq) self.block_manager.access_all_blocks_in_seq(seq, now) + common_computed_block_nums = ( + self.block_manager.get_common_computed_block_ids( + seq_group.get_seqs(status=SequenceStatus.RUNNING))) + seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, is_prompt=scheduler_outputs.prompt_run, @@ -385,8 +394,7 @@ class Scheduler: sampling_params=seq_group.sampling_params, block_tables=block_tables, lora_request=seq_group.lora_request, - computed_block_nums=self.block_manager. - get_common_computed_block_ids(seq_group), + computed_block_nums=common_computed_block_nums, state=seq_group.state, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. @@ -396,6 +404,14 @@ class Scheduler: if scheduler_outputs.prompt_run else None, ) seq_group_metadata_list.append(seq_group_metadata) + + # Now that the batch has been created, we can assume all blocks in the + # batch will have been computed before the next scheduling invocation. + # This is because the engine assumes that a failure in model execution + # will crash the vLLM instance / will not retry. + for seq_group in scheduler_outputs.scheduled_seq_groups: + self.block_manager.mark_blocks_as_computed(seq_group) + return seq_group_metadata_list, scheduler_outputs def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: @@ -503,9 +519,6 @@ class Scheduler: for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED - def mark_blocks_as_computed(self, seq_group: SequenceGroup): - self.block_manager.mark_blocks_as_computed(seq_group) - def _passed_delay(self, now: float) -> bool: if self.prev_prompt: self.last_prompt_latency = now - self.prev_time diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6dcd60a1..09f90d10 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -28,6 +28,7 @@ class EngineArgs: max_parallel_loading_workers: Optional[int] = None block_size: int = 16 enable_prefix_caching: bool = False + use_v2_block_manager: bool = False swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None @@ -52,6 +53,9 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' ray_workers_use_nsight: bool = False + + forced_num_gpu_blocks: Optional[int] = None + # Related to Vision-language models such as llava image_input_type: Optional[str] = None image_token_id: Optional[int] = None @@ -194,6 +198,9 @@ class EngineArgs: parser.add_argument('--enable-prefix-caching', action='store_true', help='Enables automatic prefix caching') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='Use BlockSpaceMangerV2') parser.add_argument('--seed', type=int, @@ -210,6 +217,12 @@ class EngineArgs: help='the fraction of GPU memory to be used for ' 'the model executor, which can range from 0 to 1.' 'If unspecified, will use the default value of 0.9.') + parser.add_argument( + '--forced-num-gpu-blocks', + type=int, + default=None, + help='If specified, ignore GPU profiling result and use this number' + 'of GPU blocks. 
Used for testing preemption.') parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, @@ -369,6 +382,7 @@ class EngineArgs: cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, + self.forced_num_gpu_blocks, model_config.get_sliding_window(), self.enable_prefix_caching) parallel_config = ParallelConfig( @@ -383,6 +397,7 @@ class EngineArgs: scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, + self.use_v2_block_manager, self.scheduler_delay_factor) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1c688397..649cd040 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -553,12 +553,6 @@ class LLMEngine: # Update the scheduled sequence groups with the model outputs. scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - # If prefix caching is enabled, mark all blocks in the sequence groups - # as completed so that future requests don't attempt to recompute them - if self.cache_config.enable_prefix_caching: - for seq_group in scheduled_seq_groups: - self.scheduler.mark_blocks_as_computed(seq_group) - for seq_group, outputs in zip(scheduled_seq_groups, output): self._process_sequence_group_outputs(seq_group, outputs) diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 90c38824..adbc4cb7 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -85,6 +85,12 @@ class GPUExecutor(ExecutorBase): cache_dtype=self.cache_config.cache_dtype, )) + if self.cache_config.forced_num_gpu_blocks is not None: + forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks + logger.info(f"Replacing profiled {num_gpu_blocks=} with " + f"{forced_num_gpu_blocks=}") + num_gpu_blocks = forced_num_gpu_blocks + logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index d8288832..4ac72bb0 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -232,6 +232,13 @@ class RayGPUExecutor(ExecutorBase): # operators can be applied to all workers. num_gpu_blocks = min(b[0] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks) + + if self.cache_config.forced_num_gpu_blocks is not None: + forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks + logger.info(f"Replacing profiled {num_gpu_blocks=} with " + f"{forced_num_gpu_blocks=}") + num_gpu_blocks = forced_num_gpu_blocks + logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") diff --git a/vllm/sequence.py b/vllm/sequence.py index b019b5bf..8292e207 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -196,6 +196,8 @@ class Sequence: return self.lora_request.lora_int_id if self.lora_request else 0 def hash_of_block(self, logical_idx: int) -> int: + # TODO This can produce incorrect hash when block size > prompt size + # Compute the number of tokens in the sequence # TODO: The current hashing function is O(L^2). We should optimize # this in the future. 
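To make the TODOs above concrete, here is a simplified, standalone sketch of prefix-based block hashing in the spirit of `Sequence.hash_of_block`. It is an illustration only, not the actual method: the function name `prefix_hash_of_block` and the flat token-list signature are invented for this example, and the real implementation may hash additional per-sequence state. The point it demonstrates is that hashing block `i` covers every token from the start of the sequence through the end of block `i`, so hashing all blocks of an L-token sequence costs O(L^2) overall, which is what the second TODO refers to.

```python
from typing import List


def prefix_hash_of_block(token_ids: List[int], logical_idx: int,
                         block_size: int) -> int:
    # Hash every token from the start of the sequence through the end of
    # logical block `logical_idx`. Hashing block i therefore touches
    # (i + 1) * block_size tokens, so hashing all blocks of an L-token
    # sequence is O(L^2) in total -- the cost the TODO above points at.
    num_hashed_tokens = block_size * (logical_idx + 1)
    return hash(tuple(token_ids[:num_hashed_tokens]))


# Example: with block_size=4, block 1's hash covers tokens 0..7.
tokens = [11, 12, 13, 14, 15, 16, 17, 18, 19]
h = prefix_hash_of_block(tokens, logical_idx=1, block_size=4)
```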
diff --git a/vllm/utils.py b/vllm/utils.py index 83d94f28..93fff4ff 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -227,6 +227,16 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None: os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) +def chunk_list(lst, chunk_size): + """Return successive chunk_size-sized chunks from lst.""" + return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)] + + +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) + + @lru_cache(maxsize=None) def get_nvcc_cuda_version() -> Optional[Version]: cuda_home = os.environ.get('CUDA_HOME')
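The behavior of the two helpers added to `vllm/utils.py` follows directly from their one-line bodies; a quick usage sketch (the specific values are illustrative):

```python
from vllm.utils import cdiv, chunk_list

# cdiv rounds up: a 1025-token sequence needs 65 blocks of 16 tokens.
assert cdiv(1025, 16) == 65
assert cdiv(1024, 16) == 64

# chunk_list splits a list into fixed-size chunks; the last chunk may be
# shorter than chunk_size.
assert chunk_list([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]
```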